Commit 217ee621 authored by 王敏
Browse files

Merge remote-tracking branch 'origin/v0.6.2-dev' into v0.6.2-dev

parents f0021a4d 3f78216a
...@@ -51,7 +51,7 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -51,7 +51,7 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_hip from vllm.utils import is_hip,W8a8GetCacheJSON
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
...@@ -424,6 +424,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -424,6 +424,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
quant_config, quant_config,
lora_config=lora_config, lora_config=lora_config,
prefix="model") prefix="model")
self.tritonsingleton= W8a8GetCacheJSON()
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
if lora_config: if lora_config:
...@@ -459,6 +461,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -459,6 +461,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1' self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1' self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '0'))
def forward( def forward(
self, self,
...@@ -648,6 +651,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -648,6 +651,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda() qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda()
qweight.data=torch.cat((qweight.data,qweight_pad),dim=1).contiguous() qweight.data=torch.cat((qweight.data,qweight_pad),dim=1).contiguous()
# This processing must not be applied when inference is backed by Triton
if self.quant_method == "compressed_tensors": if self.quant_method == "compressed_tensors":
lay_key_words = [ lay_key_words = [
"self_attn.qkv_proj.weight", "self_attn.qkv_proj.weight",
...@@ -656,14 +660,38 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -656,14 +660,38 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
"mlp.down_proj.weight", "mlp.down_proj.weight",
] ]
combined_words = "|".join(lay_key_words) combined_words = "|".join(lay_key_words)
weight_shapes=[]
all_json={}
for layername, weight in params_dict.items(): for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername) matches = re.findall(combined_words, layername)
if matches: if matches and "scale" not in layername:
weight_data =params_dict[layername] weight_data =params_dict[layername]
k=weight_data.shape[0] n=weight_data.shape[0]
_weight=weight_data.T.contiguous().reshape(k,-1)
weight_data.data.copy_(_weight) #rocblas和cutlass目前都需要weight做处理,但是triton不用
if self.w8a8_strategy!=1:
_weight=weight_data.T.contiguous().reshape(n,-1)
weight_data.data.copy_(_weight)
# Below: record the k and n values that occur in the model
elif len(weight_shapes)<4:
k=weight_data.shape[1]
weight_shapes.append({n,k})
json_file=self.tritonsingleton.get_w8a8json_name(n,k)
configs_dict=self.tritonsingleton.get_triton_cache(json_file,n,k)
if configs_dict:
all_json.update(configs_dict)
if self.w8a8_strategy==1:
self.tritonsingleton.triton_json_dict.append(all_json)
# Run one warmup pass for every config that was found
for key, value in all_json.items():
m=int(key.split('_')[0])
n=int(key.split('_')[1])
k=int(key.split('_')[2])
ops._int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,best_config=value)
# If this function is called, it should always initialize KV cache scale # If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should # factors (or else raise an exception). Thus, handled exceptions should
......
...@@ -48,7 +48,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY ...@@ -48,7 +48,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.utils import cached_get_tokenizer from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SequenceData from vllm.sequence import IntermediateTensors, SequenceData
from vllm.utils import is_list_of from vllm.utils import is_list_of,W8a8GetCacheJSON
from .utils import flatten_bn, is_pp_missing_parameter, make_layers from .utils import flatten_bn, is_pp_missing_parameter, make_layers
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
...@@ -904,6 +904,7 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal): ...@@ -904,6 +904,7 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal):
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1' self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1' self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '0'))
def _get_image_input_type( def _get_image_input_type(
self, self,
...@@ -1100,11 +1101,35 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal): ...@@ -1100,11 +1101,35 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal):
"mlp.c_proj.weight", "mlp.c_proj.weight",
] ]
combined_words = "|".join(lay_key_words) combined_words = "|".join(lay_key_words)
weight_shapes=[]
all_json={}
for layername, weight in params_dict.items(): for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername) matches = re.findall(combined_words, layername)
if matches: if matches and "scale" not in layername:
weight_data =params_dict[layername] weight_data =params_dict[layername]
k=weight_data.shape[0] n=weight_data.shape[0]
_weight=weight_data.T.contiguous().reshape(k,-1)
weight_data.data.copy_(_weight) #rocblas和cutlass目前都需要weight做处理,但是triton不用
if self.w8a8_strategy!=1:
_weight=weight_data.T.contiguous().reshape(n,-1)
weight_data.data.copy_(_weight)
# Below: record the k and n values that occur in the model
elif len(weight_shapes)<4:
k=weight_data.shape[1]
weight_shapes.append({n,k})
json_file=self.tritonsingleton.get_w8a8json_name(n,k)
configs_dict=self.tritonsingleton.get_triton_cache(json_file,n,k)
if configs_dict:
all_json.update(configs_dict)
if self.w8a8_strategy==1:
self.tritonsingleton.triton_json_dict.append(all_json)
# Run one warmup pass for every config that was found
for key, value in all_json.items():
m=int(key.split('_')[0])
n=int(key.split('_')[1])
k=int(key.split('_')[2])
ops._int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,best_config=value)
import psutil
import torch import torch
from .interface import Platform, PlatformEnum from .interface import Platform, PlatformEnum
...@@ -9,6 +10,10 @@ class CpuPlatform(Platform): ...@@ -9,6 +10,10 @@ class CpuPlatform(Platform):
@classmethod @classmethod
def get_device_name(cls, device_id: int = 0) -> str: def get_device_name(cls, device_id: int = 0) -> str:
return "cpu" return "cpu"
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
    """Return the ``total`` field of ``psutil.virtual_memory()``.

    ``device_id`` is accepted to match the platform interface but is not
    consulted here.
    """
    memory_stats = psutil.virtual_memory()
    return memory_stats.total
@classmethod @classmethod
def inference_mode(cls): def inference_mode(cls):
......
...@@ -83,6 +83,11 @@ class Platform: ...@@ -83,6 +83,11 @@ class Platform:
def get_device_name(cls, device_id: int = 0) -> str: def get_device_name(cls, device_id: int = 0) -> str:
raise NotImplementedError raise NotImplementedError
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
    """Get the total memory of a device in bytes.

    This base implementation provides no value and always raises.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError
@classmethod @classmethod
def inference_mode(cls): def inference_mode(cls):
"""A device-specific wrapper of `torch.inference_mode`. """A device-specific wrapper of `torch.inference_mode`.
......
...@@ -29,3 +29,8 @@ class RocmPlatform(Platform): ...@@ -29,3 +29,8 @@ class RocmPlatform(Platform):
@lru_cache(maxsize=8) @lru_cache(maxsize=8)
def get_device_name(cls, device_id: int = 0) -> str: def get_device_name(cls, device_id: int = 0) -> str:
return torch.cuda.get_device_name(device_id) return torch.cuda.get_device_name(device_id)
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
    """Return ``total_memory`` from the device properties of ``device_id``.

    The properties are obtained via ``torch.cuda.get_device_properties``.
    """
    return torch.cuda.get_device_properties(device_id).total_memory
...@@ -9,6 +9,10 @@ class TpuPlatform(Platform): ...@@ -9,6 +9,10 @@ class TpuPlatform(Platform):
@classmethod @classmethod
def get_device_name(cls, device_id: int = 0) -> str: def get_device_name(cls, device_id: int = 0) -> str:
raise NotImplementedError raise NotImplementedError
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
    """Total-memory query; not implemented here.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError
@classmethod @classmethod
def inference_mode(cls): def inference_mode(cls):
......
...@@ -45,7 +45,7 @@ class CustomCacheManager(FileCacheManager): ...@@ -45,7 +45,7 @@ class CustomCacheManager(FileCacheManager):
self.cache_dir = os.getenv("TRITON_CACHE_DIR", self.cache_dir = os.getenv("TRITON_CACHE_DIR",
"").strip() or default_cache_dir() "").strip() or default_cache_dir()
if self.cache_dir: if self.cache_dir:
self.cache_dir = f"{self.cache_dir}_{os.getpid()}" # self.cache_dir = f"{self.cache_dir}_{os.getpid()}"
self.cache_dir = os.path.join(self.cache_dir, self.key) self.cache_dir = os.path.join(self.cache_dir, self.key)
self.lock_path = os.path.join(self.cache_dir, "lock") self.lock_path = os.path.join(self.cache_dir, "lock")
os.makedirs(self.cache_dir, exist_ok=True) os.makedirs(self.cache_dir, exist_ok=True)
......
...@@ -16,6 +16,7 @@ import threading ...@@ -16,6 +16,7 @@ import threading
import uuid import uuid
import warnings import warnings
import weakref import weakref
import json
from asyncio import FIRST_COMPLETED, ensure_future from asyncio import FIRST_COMPLETED, ensure_future
from functools import lru_cache, partial, wraps from functools import lru_cache, partial, wraps
from platform import uname from platform import uname
...@@ -119,6 +120,9 @@ STR_XFORMERS_ATTN_VAL: str = "XFORMERS" ...@@ -119,6 +120,9 @@ STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN" STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID" STR_INVALID_VAL: str = "INVALID"
GB_bytes = 1_000_000_000
"""The number of bytes in one gigabyte (GB)."""
GiB_bytes = 1 << 30 GiB_bytes = 1 << 30
"""The number of bytes in one gibibyte (GiB).""" """The number of bytes in one gibibyte (GiB)."""
...@@ -1331,3 +1335,88 @@ class AtomicCounter: ...@@ -1331,3 +1335,88 @@ class AtomicCounter:
@property @property
def value(self): def value(self):
return self._value return self._value
class W8a8GetCacheJSON:
    """Singleton that reads tuned W8A8 Triton GEMM configs from JSON files.

    Every call to ``W8a8GetCacheJSON(...)`` returns the same instance.

    Attributes set on first construction:
        triton_json_dir: directory holding the JSON files; taken from the
            ``TRITON_JSON_DIR`` environment variable, default ``./cache``.
        triton_json_dict: list that callers append parsed config dicts to.

    Cache files are two-level JSON objects. Each inner entry is flattened
    into ``"<inner key>_<outer key>"`` -> config dict.
    NOTE(review): the original comment describes the flattened key as
    ``M_N_K``, so outer keys are presumably ``"N_K"`` and inner keys ``M``
    -- confirm against the tuning script that writes these files.
    """

    _instance = None

    # Fields copied (converted to int) from every entry of a cache file.
    _CONFIG_FIELDS = (
        "SPLIT_K",
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_N",
        "BLOCK_SIZE_K",
        "GROUP_SIZE_M",
        "num_stages",
        "num_warps",
    )

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, creating it on first use.

        Extra arguments are accepted and ignored. They are deliberately not
        forwarded to ``object.__new__``, which raises ``TypeError`` when it
        receives any.
        """
        if cls._instance is None:
            instance = super().__new__(cls)
            instance._initialize()
            # Publish only after initialisation succeeded, so a failure
            # cannot leave a half-built singleton behind.
            cls._instance = instance
        return cls._instance

    def _initialize(self):
        """Set the instance attributes; runs once, on first construction."""
        self.triton_json_dir = os.getenv('TRITON_JSON_DIR', './cache')
        self.triton_json_dict = []

    @classmethod
    def _parse_cache(cls, cachedata):
        """Flatten a loaded cache file into ``{"<inner>_<outer>": config}``.

        Raises:
            KeyError: if an entry lacks one of ``_CONFIG_FIELDS``.
            ValueError: if a field value cannot be converted to ``int``.
        """
        configs_dict = {}
        for outer_key, entries in cachedata.items():
            for inner_key, raw_config in entries.items():
                configs_dict[f"{inner_key}_{outer_key}"] = {
                    field: int(raw_config[field])
                    for field in cls._CONFIG_FIELDS
                }
        return configs_dict

    def getspec_config(self, configs_dict, M, N, K):
        """Return the config stored under ``"M_N_K"``, or ``None`` if absent."""
        return configs_dict.get(f"{M}_{N}_{K}")

    def get_triton_cache_tune(self, file_path, n, k):
        """Load ``file_path`` for tuning, creating an empty cache if missing.

        When the file does not exist, its parent directory (if any) is
        created and an empty JSON object is written to ``file_path``.

        Args:
            file_path: path of the JSON cache file.
            n, k: unused; kept for interface compatibility.

        Returns:
            The flattened config dict (empty for a newly created file).
        """
        if os.path.exists(file_path):
            with open(file_path, 'r', encoding='utf-8') as file:
                cachedata = json.load(file)
        else:
            folder_path = os.path.dirname(file_path)
            # A bare file name has an empty dirname, and os.makedirs('')
            # raises FileNotFoundError, so only create real directories.
            if folder_path:
                os.makedirs(folder_path, exist_ok=True)
            cachedata = {}
            # Write the empty data to the new JSON file.
            with open(file_path, 'w', encoding='utf-8') as file:
                json.dump(cachedata, file)
        return self._parse_cache(cachedata)

    def get_triton_cache(self, file_path, n, k):
        """Load ``file_path`` outside of tuning.

        Args:
            file_path: path of the JSON cache file.
            n, k: unused; kept for interface compatibility.

        Returns:
            The flattened config dict, or ``None`` when the file does not
            exist.
        """
        if not os.path.exists(file_path):
            return None
        with open(file_path, 'r', encoding='utf-8') as file:
            cachedata = json.load(file)
        return self._parse_cache(cachedata)

    def get_w8a8json_name(self, n, k):
        """Build the cache file path for an ``(n, k)`` pair on this device.

        The device name comes from ``current_platform.get_device_name()``
        with spaces replaced by underscores.
        """
        device_name = current_platform.get_device_name().replace(" ", "_")
        return self.triton_json_dir + f"/W8A8_{n}_{k}_DCU{device_name}.json"
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment