Commit 2d5a25cd authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch '0.6.2-w8a8' into 'v0.6.2-dev'

0.6.2 w8a8

See merge request dcutoolkit/deeplearing/vllm!43
parents 0dc55ec0 26513bb5
...@@ -9,7 +9,7 @@ vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention ...@@ -9,7 +9,7 @@ vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention
## 支持模型结构列表 ## 支持模型结构列表
| 结构 | 模型 | FP16/BF16 | AWQ | GPTQ | | 结构 | 模型 | FP16/BF16 | AWQ | GPTQ |
| :------: | :------: | :------: | :------: | | :------: | :------: | :------: | :------: |:------: |
| LlamaForCausalLM | Llama 3.1,Llama 3,Llama 2,Llama,Yi,Codellama,deepseek | Yes | Yes | Yes | | LlamaForCausalLM | Llama 3.1,Llama 3,Llama 2,Llama,Yi,Codellama,deepseek | Yes | Yes | Yes |
| QWenLMHeadModel | QWen,Qwen-VL | Yes | Yes | Yes | | QWenLMHeadModel | QWen,Qwen-VL | Yes | Yes | Yes |
| Qwen2ForCausalLM | QWen2,QWen1.5,CodeQwen1.5 | Yes | Yes | Yes | | Qwen2ForCausalLM | QWen2,QWen1.5,CodeQwen1.5 | Yes | Yes | Yes |
......
This diff is collapsed.
...@@ -12,7 +12,7 @@ from vllm.platforms import current_platform ...@@ -12,7 +12,7 @@ from vllm.platforms import current_platform
try: try:
from lmslim import quant_ops from lmslim import quant_ops
except Exception: except Exception:
print("INFO: Please install lmslim if you want to infer gptq or awq model.\n") print("INFO: Please install lmslim if you want to infer gptq or awq or w8a8 model.\n")
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -706,9 +706,9 @@ def cutlass_scaled_mm(a: torch.Tensor, ...@@ -706,9 +706,9 @@ def cutlass_scaled_mm(a: torch.Tensor,
# torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) # torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
# return out # return out
#return quant_ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
return quant_ops.rocblas_scaled_mm_nn(a, b, scale_a, scale_b, out_dtype, bias) return quant_ops.rocblas_scaled_mm_nn(a, b, scale_a, scale_b, out_dtype, bias)
def rocblas_scaled_mm(a: torch.Tensor, def rocblas_scaled_mm(a: torch.Tensor,
b: torch.Tensor, b: torch.Tensor,
scale_a: torch.Tensor, scale_a: torch.Tensor,
......
...@@ -4,12 +4,12 @@ import torch ...@@ -4,12 +4,12 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import is_hip from vllm.utils import is_hip,W8a8GetCacheJSON
# Input scaling factors are no longer optional in _scaled_mm starting # Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() if is_hip() else None TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() if is_hip() else None
W8A8_TRITONJSON=W8a8GetCacheJSON()
def cutlass_fp8_supported() -> bool: def cutlass_fp8_supported() -> bool:
# cutlass is not supported on Rocm # cutlass is not supported on Rocm
...@@ -200,12 +200,37 @@ def apply_int8_linear( ...@@ -200,12 +200,37 @@ def apply_int8_linear(
x_q, x_scale, _ = ops.scaled_int8_quant(input, input_scale) x_q, x_scale, _ = ops.scaled_int8_quant(input, input_scale)
if w8a8_strategy==1: if w8a8_strategy==1:
m=x_q.shape[0]
k=x_q.shape[1]
n=weight.shape[1]
if f"{m}_{n}_{k}" in W8A8_TRITONJSON.triton_json_dict[0]:
best_config=W8A8_TRITONJSON.triton_json_dict[0][f"{m}_{n}_{k}"]
#print("json files:",best_config)
elif f"1_{n}_{k}" in W8A8_TRITONJSON.triton_json_dict[0]:
if m<64:
m_= 32
elif m<128:
m_=64
elif m<256:
m_=128
elif m<512:
m_=256
elif m<1024:
m_=512
else:
m_=1024
best_config=W8A8_TRITONJSON.triton_json_dict[0][f"{m_}_{n}_{k}"]
else:
best_config=None
print("config not found!")
return ops.triton_scaled_mm(x_q, return ops.triton_scaled_mm(x_q,
weight, weight,
scale_a=x_scale, scale_a=x_scale,
scale_b=weight_scale, scale_b=weight_scale,
out_dtype=input.dtype, out_dtype=input.dtype,
bias=bias) bias=bias,best_config=best_config)
elif w8a8_strategy==2: elif w8a8_strategy==2:
return ops.cutlass_scaled_mm(x_q, return ops.cutlass_scaled_mm(x_q,
weight, weight,
......
...@@ -51,7 +51,7 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -51,7 +51,7 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_hip from vllm.utils import is_hip,W8a8GetCacheJSON
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
...@@ -424,6 +424,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -424,6 +424,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
quant_config, quant_config,
lora_config=lora_config, lora_config=lora_config,
prefix="model") prefix="model")
self.tritonsingleton= W8a8GetCacheJSON()
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
if lora_config: if lora_config:
...@@ -459,6 +461,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -459,6 +461,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1' self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1' self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '0'))
def forward( def forward(
self, self,
...@@ -648,6 +651,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -648,6 +651,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda() qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda()
qweight.data=torch.cat((qweight.data,qweight_pad),dim=1).contiguous() qweight.data=torch.cat((qweight.data,qweight_pad),dim=1).contiguous()
#当为triton支持推理的时候不能进行处理
if self.quant_method == "compressed_tensors": if self.quant_method == "compressed_tensors":
lay_key_words = [ lay_key_words = [
"self_attn.qkv_proj.weight", "self_attn.qkv_proj.weight",
...@@ -656,15 +660,39 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -656,15 +660,39 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
"mlp.down_proj.weight", "mlp.down_proj.weight",
] ]
combined_words = "|".join(lay_key_words) combined_words = "|".join(lay_key_words)
weight_shapes=[]
all_json={}
for layername, weight in params_dict.items(): for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername) matches = re.findall(combined_words, layername)
if matches: if matches and "scale" not in layername:
weight_data =params_dict[layername] weight_data =params_dict[layername]
k=weight_data.shape[0] n=weight_data.shape[0]
_weight=weight_data.T.contiguous().reshape(k,-1)
#rocblas和cutlass目前都需要weight做处理,但是triton不用
if self.w8a8_strategy!=1:
_weight=weight_data.T.contiguous().reshape(n,-1)
weight_data.data.copy_(_weight) weight_data.data.copy_(_weight)
#下面是针对模型记录模型出现k和n值
elif len(weight_shapes)<4:
k=weight_data.shape[1]
weight_shapes.append({n,k})
json_file=self.tritonsingleton.get_w8a8json_name(n,k)
configs_dict=self.tritonsingleton.get_triton_cache(json_file,n,k)
if configs_dict:
all_json.update(configs_dict)
if self.w8a8_strategy==1:
self.tritonsingleton.triton_json_dict.append(all_json)
#找到的所有config都进行一次warmup
for key, value in all_json.items():
m=int(key.split('_')[0])
n=int(key.split('_')[1])
k=int(key.split('_')[2])
ops._int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,best_config=value)
# If this function is called, it should always initialize KV cache scale # If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should # factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state # make sure to leave KV cache scale factors in a known good (dummy) state
......
...@@ -48,7 +48,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY ...@@ -48,7 +48,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.utils import cached_get_tokenizer from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SequenceData from vllm.sequence import IntermediateTensors, SequenceData
from vllm.utils import is_list_of from vllm.utils import is_list_of,W8a8GetCacheJSON
from .utils import flatten_bn, is_pp_missing_parameter, make_layers from .utils import flatten_bn, is_pp_missing_parameter, make_layers
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
...@@ -904,6 +904,7 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal): ...@@ -904,6 +904,7 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal):
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1' self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1' self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '0'))
def _get_image_input_type( def _get_image_input_type(
self, self,
...@@ -1100,11 +1101,35 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal): ...@@ -1100,11 +1101,35 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal):
"mlp.c_proj.weight", "mlp.c_proj.weight",
] ]
combined_words = "|".join(lay_key_words) combined_words = "|".join(lay_key_words)
weight_shapes=[]
all_json={}
for layername, weight in params_dict.items(): for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername) matches = re.findall(combined_words, layername)
if matches: if matches and "scale" not in layername:
weight_data =params_dict[layername] weight_data =params_dict[layername]
k=weight_data.shape[0] n=weight_data.shape[0]
_weight=weight_data.T.contiguous().reshape(k,-1)
#rocblas和cutlass目前都需要weight做处理,但是triton不用
if self.w8a8_strategy!=1:
_weight=weight_data.T.contiguous().reshape(n,-1)
weight_data.data.copy_(_weight) weight_data.data.copy_(_weight)
#下面是针对模型记录模型出现k和n值
elif len(weight_shapes)<4:
k=weight_data.shape[1]
weight_shapes.append({n,k})
json_file=self.tritonsingleton.get_w8a8json_name(n,k)
configs_dict=self.tritonsingleton.get_triton_cache(json_file,n,k)
if configs_dict:
all_json.update(configs_dict)
if self.w8a8_strategy==1:
self.tritonsingleton.triton_json_dict.append(all_json)
#找到的所有config都进行一次warmup
for key, value in all_json.items():
m=int(key.split('_')[0])
n=int(key.split('_')[1])
k=int(key.split('_')[2])
ops._int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,best_config=value)
...@@ -16,6 +16,7 @@ import threading ...@@ -16,6 +16,7 @@ import threading
import uuid import uuid
import warnings import warnings
import weakref import weakref
import json
from asyncio import FIRST_COMPLETED, ensure_future from asyncio import FIRST_COMPLETED, ensure_future
from functools import lru_cache, partial, wraps from functools import lru_cache, partial, wraps
from platform import uname from platform import uname
...@@ -1334,3 +1335,88 @@ class AtomicCounter: ...@@ -1334,3 +1335,88 @@ class AtomicCounter:
@property @property
def value(self): def value(self):
return self._value return self._value
class W8a8GetCacheJSON:
    """Process-wide singleton that loads tuned Triton W8A8 GEMM configs.

    Every call to ``W8a8GetCacheJSON()`` returns the same object. The
    instance exposes two attributes:

    * ``triton_json_dir`` -- directory holding the cache JSON files, read
      from the ``TRITON_JSON_DIR`` environment variable (default
      ``./cache``) when the singleton is first created.
    * ``triton_json_dict`` -- a list, initially empty, that callers append
      flattened config dictionaries to.

    NOTE(review): ``get_w8a8json_name`` relies on a module-level
    ``current_platform`` name that is not defined in this class; confirm
    the enclosing module imports it.
    """

    _instance = None

    # Fields copied from each cached kernel config; every one is coerced
    # to ``int``. The order matches the key order of the produced dicts.
    _CONFIG_FIELDS = (
        "SPLIT_K",
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_N",
        "BLOCK_SIZE_K",
        "GROUP_SIZE_M",
        "num_stages",
        "num_warps",
    )

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, creating it on first use.

        Positional and keyword arguments are accepted for backward
        compatibility but are ignored.
        """
        if cls._instance is None:
            # ``object.__new__`` raises TypeError when it receives extra
            # arguments, so only ``cls`` is forwarded.
            instance = super().__new__(cls)
            instance._initialize()
            cls._instance = instance
        return cls._instance

    def _initialize(self):
        """Set up the instance attributes (run once, from ``__new__``)."""
        self.triton_json_dir = os.getenv('TRITON_JSON_DIR', './cache')
        self.triton_json_dict = []

    def getspec_config(self, configs_dict, M, N, K):
        """Return the config stored under ``"{M}_{N}_{K}"``.

        Returns ``None`` when ``configs_dict`` has no such key.
        """
        return configs_dict.get(f"{M}_{N}_{K}")

    @classmethod
    def _flatten_cache(cls, cachedata):
        """Flatten a two-level cache mapping into ``{key: config}``.

        ``cachedata`` maps an outer key to ``{sub_key: raw_config}``. Each
        entry is stored under ``"{sub_key}_{outer_key}"`` (the original
        comment describes the result as ``[M_N_K]: [config]``), with the
        fields listed in ``_CONFIG_FIELDS`` converted to ``int``.

        Raises:
            KeyError: if a raw config lacks one of the expected fields.
        """
        configs_dict = {}
        for outer_key, per_sub_key in cachedata.items():
            for sub_key, raw_config in per_sub_key.items():
                configs_dict[f"{sub_key}_{outer_key}"] = {
                    field: int(raw_config[field])
                    for field in cls._CONFIG_FIELDS
                }
        return configs_dict

    @staticmethod
    def _load_json(file_path):
        """Read and decode the JSON document stored at ``file_path``."""
        with open(file_path, 'r', encoding='utf-8') as file:
            return json.load(file)

    def get_triton_cache_tune(self, file_path, n, k):
        """Load the cache at ``file_path`` for tuning, creating it if absent.

        When the file does not exist, its parent directory is created and
        an empty JSON object is written, so the result is ``{}``.

        Args:
            file_path: path of the cache JSON file.
            n: unused; kept for interface compatibility.
            k: unused; kept for interface compatibility.

        Returns:
            The flattened ``{key: config}`` dictionary.
        """
        if os.path.exists(file_path):
            cachedata = self._load_json(file_path)
        else:
            folder_path = os.path.dirname(file_path)
            # A bare file name has an empty dirname, and os.makedirs('')
            # raises, so only create the folder when there is one.
            if folder_path:
                os.makedirs(folder_path, exist_ok=True)
            cachedata = {}
            # Seed the new JSON file with an empty object.
            with open(file_path, 'w', encoding='utf-8') as file:
                json.dump(cachedata, file)
        return self._flatten_cache(cachedata)

    def get_triton_cache(self, file_path, n, k):
        """Load the cache at ``file_path`` outside of tuning.

        Args:
            file_path: path of the cache JSON file.
            n: unused; kept for interface compatibility.
            k: unused; kept for interface compatibility.

        Returns:
            The flattened ``{key: config}`` dictionary, or ``None`` when
            the file does not exist (nothing is created in that case).
        """
        if not os.path.exists(file_path):
            return None
        return self._flatten_cache(self._load_json(file_path))

    def get_w8a8json_name(self, n, k):
        """Build the cache file path for an ``(n, k)`` pair on this device.

        The device name comes from ``current_platform.get_device_name()``
        with spaces replaced by underscores.
        """
        device_name = current_platform.get_device_name().replace(" ", "_")
        return self.triton_json_dir+f"/W8A8_{n}_{k}_DCU{device_name}.json"
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment