Commit 35a8304d authored by zhuwenwen's avatar zhuwenwen
Browse files

添加w8a8 rocblas非融合支持

parent 5a9c236d
...@@ -9,25 +9,16 @@ vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention ...@@ -9,25 +9,16 @@ vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention
## 支持模型结构列表 ## 支持模型结构列表
| 结构 | 模型 | 模型并行 | FP16 | | 结构 | 模型 | 模型并行 | FP16 |
| :----------: | :----------: | :------: | :--: | | :------: | :------: | :------: | :------: |
| LlamaForCausalLM | LLaMA | Yes | Yes | | LlamaForCausalLM | LLaMA、LLaMA-2、LLaMA-3、Codellama、deepseek、Yi | Yes | Yes |
| LlamaForCausalLM | LLaMA-2 | Yes | Yes |
| LlamaForCausalLM | LLaMA-3 | Yes | Yes |
| LlamaForCausalLM | Codellama | Yes | Yes |
| QWenLMHeadModel | QWen | Yes | Yes | | QWenLMHeadModel | QWen | Yes | Yes |
| Qwen2ForCausalLM | QWen1.5 | Yes | Yes | | Qwen2ForCausalLM | QWen1.5、CodeQwen1.5、QWen2 | Yes | Yes |
| Qwen2ForCausalLM | CodeQwen1.5 | Yes | Yes | | ChatGLMModel | chatglm2、chatglm3 | Yes | Yes |
| Qwen2ForCausalLM | QWen2 | Yes | Yes | | BaiChuanForCausalLM | Baichuan、Baichuan2 | Yes | Yes |
| ChatGLMModel | chatglm2 | Yes | Yes | | BloomForCausalLM     | BLOOM        | Yes | Yes |
| ChatGLMModel | chatglm3 | Yes | Yes |
| BaiChuanForCausalLM | Baichuan | Yes | Yes |
| BaiChuanForCausalLM | Baichuan2 | Yes | Yes |
|    BloomForCausalLM       |    BLOOM          |   Yes    | Yes  |
| InternLMForCausalLM | InternLM | Yes | Yes | | InternLMForCausalLM | InternLM | Yes | Yes |
| InternLM2ForCausalLM | InternLM2 | Yes | Yes | | InternLM2ForCausalLM | InternLM2 | Yes | Yes |
| LlamaForCausalLM | deepseek | Yes | Yes |
| DeepseekV2ForCausalLM | DeepSeek-V2 | Yes | Yes | | DeepseekV2ForCausalLM | DeepSeek-V2 | Yes | Yes |
| LlamaForCausalLM | Yi | Yes | Yes |
| MixtralForCausalLM | Mixtral-8x7B | Yes | Yes | | MixtralForCausalLM | Mixtral-8x7B | Yes | Yes |
......
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
if __name__ == '__main__': # Sample prompts.
# Sample prompts. prompts = [
prompts = [
"Hello, my name is", "Hello, my name is",
"The president of the United States is", "The president of the United States is",
"The capital of France is", "The capital of France is",
"The future of AI is", "The future of AI is",
] ]
# Create a sampling params object. # Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=16) sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=16)
# Create an LLM. # Create an LLM.
llm = LLM(model="facebook/opt-125m",tensor_parallel_size=1, distributed_executor_backend="ray", dtype="float16",trust_remote_code=True, enforce_eager=True) llm = LLM(model="facebook/opt-125m",tensor_parallel_size=1, distributed_executor_backend="ray", dtype="float16",trust_remote_code=True, enforce_eager=True)
# Generate texts from the prompts. The output is a list of RequestOutput objects # Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information. # that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params) outputs = llm.generate(prompts, sampling_params)
# Print the outputs. # Print the outputs.
for output in outputs: for output in outputs:
prompt = output.prompt prompt = output.prompt
generated_text = output.outputs[0].text generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
...@@ -15,4 +15,4 @@ torch == 2.3.0 ...@@ -15,4 +15,4 @@ torch == 2.3.0
triton == 2.1.0 triton == 2.1.0
flash_attn == 2.6.1 flash_attn == 2.6.1
xformers == 0.0.25 xformers == 0.0.25
lmslim == 0.1.1 lmslim == 0.1.2
\ No newline at end of file \ No newline at end of file
...@@ -615,18 +615,19 @@ def cutlass_scaled_mm(a: torch.Tensor, ...@@ -615,18 +615,19 @@ def cutlass_scaled_mm(a: torch.Tensor,
scale_b: torch.Tensor, scale_b: torch.Tensor,
out_dtype: torch.dtype, out_dtype: torch.dtype,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) # assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) # assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
assert bias is None or bias.shape[0] == b.shape[ # assert bias is None or bias.shape[0] == b.shape[
1] and bias.dtype == out_dtype # 1] and bias.dtype == out_dtype
m = a.shape[0] # m = a.shape[0]
n = b.shape[1] # n = b.shape[1]
out = torch.empty((m, n), dtype=out_dtype, device=a.device) # out = torch.empty((m, n), dtype=out_dtype, device=a.device)
torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) # torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
return out # return out
return quant_ops.rocblas_scaled_mm_nn(a, b, scale_a, scale_b, out_dtype, bias)
def cutlass_scaled_mm_azp(a: torch.Tensor, def cutlass_scaled_mm_azp(a: torch.Tensor,
......
...@@ -407,7 +407,7 @@ class DefaultModelLoader(BaseModelLoader): ...@@ -407,7 +407,7 @@ class DefaultModelLoader(BaseModelLoader):
for _, module in model.named_modules(): for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None) quant_method = getattr(module, "quant_method", None)
if quant_method is not None and quant_method!="awq" and quant_method!="gptq": if quant_method is not None and quant_method != "awq" and quant_method != "gptq" and quant_method != "compressed_tensors":
# When quant methods need to process weights after loading # When quant methods need to process weights after loading
# (for repacking, quantizing, etc), they expect parameters # (for repacking, quantizing, etc), they expect parameters
# to be on the global target device. This scope is for the # to be on the global target device. This scope is for the
......
...@@ -28,8 +28,8 @@ def get_model_architecture( ...@@ -28,8 +28,8 @@ def get_model_architecture(
os.environ['LLAMA_NN'] = '1' os.environ['LLAMA_NN'] = '1'
if os.getenv('GEMM_PAD') != '1': if os.getenv('GEMM_PAD') != '1':
os.environ['GEMM_PAD'] = '0' os.environ['GEMM_PAD'] = '0'
if os.getenv('FA_PAD') != '0': if os.getenv('FA_PAD') != '1':
os.environ['FA_PAD'] = '1' os.environ['FA_PAD'] = '0'
else: else:
os.environ['LLAMA_NN'] = '0' os.environ['LLAMA_NN'] = '0'
os.environ['GEMM_PAD'] = '0' os.environ['GEMM_PAD'] = '0'
......
...@@ -639,6 +639,23 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -639,6 +639,23 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda() qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda()
qweight.data=torch.cat((qweight.data,qweight_pad),dim=1).contiguous() qweight.data=torch.cat((qweight.data,qweight_pad),dim=1).contiguous()
if self.quant_method == "compressed_tensors":
lay_key_words = [
"self_attn.qkv_proj.weight",
"self_attn.o_proj.weight",
"mlp.gate_up_proj.weight",
"mlp.down_proj.weight",
]
combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
if matches:
weight_data =params_dict[layername]
k=weight_data.shape[0]
_weight=weight_data.T.contiguous().reshape(k,-1)
weight_data.data.copy_(_weight)
# If this function is called, it should always initialize KV cache scale # If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should # factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state # make sure to leave KV cache scale factors in a known good (dummy) state
......
...@@ -1091,3 +1091,19 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal): ...@@ -1091,3 +1091,19 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal):
qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda() qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda()
qweight.data=torch.cat((qweight.data,qweight_pad),dim=1).contiguous() qweight.data=torch.cat((qweight.data,qweight_pad),dim=1).contiguous()
if self.quant_method == "compressed_tensors":
lay_key_words = [
"attn.c_attn.weight",
"attn.c_proj.weight",
"mlp.gate_up_proj.weight",
"mlp.c_proj.weight",
]
combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
if matches:
weight_data =params_dict[layername]
k=weight_data.shape[0]
_weight=weight_data.T.contiguous().reshape(k,-1)
weight_data.data.copy_(_weight)
...@@ -543,3 +543,19 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA): ...@@ -543,3 +543,19 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda() qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda()
qweight.data=torch.cat((qweight.data,qweight_pad),dim=1).contiguous() qweight.data=torch.cat((qweight.data,qweight_pad),dim=1).contiguous()
if self.quant_method == "compressed_tensors":
lay_key_words = [
"self_attn.qkv_proj.weight",
"self_attn.o_proj.weight",
"mlp.gate_up_proj.weight",
"mlp.down_proj.weight",
]
combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
if matches:
weight_data =params_dict[layername]
k=weight_data.shape[0]
_weight=weight_data.T.contiguous().reshape(k,-1)
weight_data.data.copy_(_weight)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment