Unverified Commit e3e79e9e authored by Woosuk Kwon, committed by GitHub

Implement AWQ quantization support for LLaMA (#1032)


Co-authored-by: Robert Irvine <robert@seamlessml.com>
Co-authored-by: root <rirv938@gmail.com>
Co-authored-by: Casper <casperbh.96@gmail.com>
Co-authored-by: julian-q <julianhquevedo@gmail.com>
parent b9fe4616
...@@ -173,3 +173,7 @@ cython_debug/ ...@@ -173,3 +173,7 @@ cython_debug/
# Sphinx documentation # Sphinx documentation
_build/ _build/
# vim swap files
*.swo
*.swp
...@@ -18,6 +18,7 @@ def main(args: argparse.Namespace): ...@@ -18,6 +18,7 @@ def main(args: argparse.Namespace):
llm = LLM( llm = LLM(
model=args.model, model=args.model,
tokenizer=args.tokenizer, tokenizer=args.tokenizer,
quantization=args.quantization,
tensor_parallel_size=args.tensor_parallel_size, tensor_parallel_size=args.tensor_parallel_size,
max_num_seqs=args.batch_size, max_num_seqs=args.batch_size,
max_num_batched_tokens=args.batch_size * args.input_len, max_num_batched_tokens=args.batch_size * args.input_len,
...@@ -63,19 +64,28 @@ def main(args: argparse.Namespace): ...@@ -63,19 +64,28 @@ def main(args: argparse.Namespace):
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Benchmark the latency of processing a single batch of ' description='Benchmark the latency of processing a single batch of '
'requests till completion.') 'requests till completion.')
parser.add_argument('--model', type=str, default='facebook/opt-125m') parser.add_argument('--model', type=str, default='facebook/opt-125m')
parser.add_argument('--tokenizer', type=str, default=None) parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--quantization',
'-q',
choices=['awq', None],
default=None)
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1) parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
parser.add_argument('--input-len', type=int, default=32) parser.add_argument('--input-len', type=int, default=32)
parser.add_argument('--output-len', type=int, default=128) parser.add_argument('--output-len', type=int, default=128)
parser.add_argument('--batch-size', type=int, default=8) parser.add_argument('--batch-size', type=int, default=8)
parser.add_argument('--n', type=int, default=1, parser.add_argument('--n',
type=int,
default=1,
help='Number of generated sequences per prompt.') help='Number of generated sequences per prompt.')
parser.add_argument('--use-beam-search', action='store_true') parser.add_argument('--use-beam-search', action='store_true')
parser.add_argument('--num-iters', type=int, default=3, parser.add_argument('--num-iters',
type=int,
default=3,
help='Number of iterations to run.') help='Number of iterations to run.')
parser.add_argument('--trust-remote-code', action='store_true', parser.add_argument('--trust-remote-code',
action='store_true',
help='trust remote code from huggingface') help='trust remote code from huggingface')
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)
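The benchmark simply forwards the new flag to the LLM constructor. A minimal, hedged sketch of the same offline API is below; the checkpoint id is a placeholder for any AWQ-quantized LLaMA model, not a reference to a specific repository.

from vllm import LLM, SamplingParams

# Placeholder repo id; substitute an AWQ-quantized LLaMA checkpoint.
llm = LLM(model="your-org/llama-7b-awq", quantization="awq")
sampling_params = SamplingParams(temperature=0.8, max_tokens=32)
outputs = llm.generate(["Hello, my name is"], sampling_params)
for output in outputs:
    print(output.outputs[0].text)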
...@@ -3,7 +3,7 @@ import argparse ...@@ -3,7 +3,7 @@ import argparse
import json import json
import random import random
import time import time
from typing import List, Tuple from typing import List, Optional, Tuple
import torch import torch
from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
...@@ -22,15 +22,10 @@ def sample_requests( ...@@ -22,15 +22,10 @@ def sample_requests(
with open(dataset_path) as f: with open(dataset_path) as f:
dataset = json.load(f) dataset = json.load(f)
# Filter out the conversations with less than 2 turns. # Filter out the conversations with less than 2 turns.
dataset = [ dataset = [data for data in dataset if len(data["conversations"]) >= 2]
data for data in dataset
if len(data["conversations"]) >= 2
]
# Only keep the first two turns of each conversation. # Only keep the first two turns of each conversation.
dataset = [ dataset = [(data["conversations"][0]["value"],
(data["conversations"][0]["value"], data["conversations"][1]["value"]) data["conversations"][1]["value"]) for data in dataset]
for data in dataset
]
# Tokenize the prompts and completions. # Tokenize the prompts and completions.
prompts = [prompt for prompt, _ in dataset] prompts = [prompt for prompt, _ in dataset]
...@@ -63,6 +58,7 @@ def run_vllm( ...@@ -63,6 +58,7 @@ def run_vllm(
requests: List[Tuple[str, int, int]], requests: List[Tuple[str, int, int]],
model: str, model: str,
tokenizer: str, tokenizer: str,
quantization: Optional[str],
tensor_parallel_size: int, tensor_parallel_size: int,
seed: int, seed: int,
n: int, n: int,
...@@ -72,6 +68,7 @@ def run_vllm( ...@@ -72,6 +68,7 @@ def run_vllm(
llm = LLM( llm = LLM(
model=model, model=model,
tokenizer=tokenizer, tokenizer=tokenizer,
quantization=quantization,
tensor_parallel_size=tensor_parallel_size, tensor_parallel_size=tensor_parallel_size,
seed=seed, seed=seed,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
...@@ -111,8 +108,8 @@ def run_hf( ...@@ -111,8 +108,8 @@ def run_hf(
trust_remote_code: bool, trust_remote_code: bool,
) -> float: ) -> float:
assert not use_beam_search assert not use_beam_search
llm = AutoModelForCausalLM.from_pretrained(model, llm = AutoModelForCausalLM.from_pretrained(
torch_dtype=torch.float16, trust_remote_code=trust_remote_code) model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
if llm.config.model_type == "llama": if llm.config.model_type == "llama":
# To enable padding in the HF backend. # To enable padding in the HF backend.
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
...@@ -132,13 +129,14 @@ def run_hf( ...@@ -132,13 +129,14 @@ def run_hf(
if len(batch) < max_batch_size and i != len(requests) - 1: if len(batch) < max_batch_size and i != len(requests) - 1:
# Check if we can add more requests to the batch. # Check if we can add more requests to the batch.
_, next_prompt_len, next_output_len = requests[i + 1] _, next_prompt_len, next_output_len = requests[i + 1]
if (max(max_prompt_len, next_prompt_len) + max( if (max(max_prompt_len, next_prompt_len) +
max_output_len, next_output_len)) <= 2048: max(max_output_len, next_output_len)) <= 2048:
# We can add more requests to the batch. # We can add more requests to the batch.
continue continue
# Generate the sequences. # Generate the sequences.
input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids input_ids = tokenizer(batch, return_tensors="pt",
padding=True).input_ids
llm_outputs = llm.generate( llm_outputs = llm.generate(
input_ids=input_ids.cuda(), input_ids=input_ids.cuda(),
do_sample=not use_beam_search, do_sample=not use_beam_search,
...@@ -165,44 +163,58 @@ def main(args: argparse.Namespace): ...@@ -165,44 +163,58 @@ def main(args: argparse.Namespace):
random.seed(args.seed) random.seed(args.seed)
# Sample the requests. # Sample the requests.
tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code) tokenizer = get_tokenizer(args.tokenizer,
trust_remote_code=args.trust_remote_code)
requests = sample_requests(args.dataset, args.num_prompts, tokenizer) requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
if args.backend == "vllm": if args.backend == "vllm":
elapsed_time = run_vllm( elapsed_time = run_vllm(requests, args.model, args.tokenizer,
requests, args.model, args.tokenizer, args.tensor_parallel_size, args.quantization, args.tensor_parallel_size,
args.seed, args.n, args.use_beam_search, args.trust_remote_code) args.seed, args.n, args.use_beam_search,
args.trust_remote_code)
elif args.backend == "hf": elif args.backend == "hf":
assert args.tensor_parallel_size == 1 assert args.tensor_parallel_size == 1
elapsed_time = run_hf( elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
requests, args.model, tokenizer, args.n, args.use_beam_search, args.use_beam_search, args.hf_max_batch_size,
args.hf_max_batch_size, args.trust_remote_code) args.trust_remote_code)
else: else:
raise ValueError(f"Unknown backend: {args.backend}") raise ValueError(f"Unknown backend: {args.backend}")
total_num_tokens = sum( total_num_tokens = sum(prompt_len + output_len
prompt_len + output_len for _, prompt_len, output_len in requests)
for _, prompt_len, output_len in requests
)
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_num_tokens / elapsed_time:.2f} tokens/s") f"{total_num_tokens / elapsed_time:.2f} tokens/s")
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Benchmark the throughput.") parser = argparse.ArgumentParser(description="Benchmark the throughput.")
parser.add_argument("--backend", type=str, choices=["vllm", "hf"], parser.add_argument("--backend",
type=str,
choices=["vllm", "hf"],
default="vllm") default="vllm")
parser.add_argument("--dataset", type=str, required=True, parser.add_argument("--dataset",
type=str,
required=True,
help="Path to the dataset.") help="Path to the dataset.")
parser.add_argument("--model", type=str, default="facebook/opt-125m") parser.add_argument("--model", type=str, default="facebook/opt-125m")
parser.add_argument("--tokenizer", type=str, default=None) parser.add_argument("--tokenizer", type=str, default=None)
parser.add_argument('--quantization',
'-q',
choices=['awq', None],
default=None)
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1) parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
parser.add_argument("--n", type=int, default=1, parser.add_argument("--n",
type=int,
default=1,
help="Number of generated sequences per prompt.") help="Number of generated sequences per prompt.")
parser.add_argument("--use-beam-search", action="store_true") parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument("--num-prompts", type=int, default=1000, parser.add_argument("--num-prompts",
type=int,
default=1000,
help="Number of prompts to process.") help="Number of prompts to process.")
parser.add_argument("--seed", type=int, default=0) parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--hf-max-batch-size", type=int, default=None, parser.add_argument("--hf-max-batch-size",
type=int,
default=None,
help="Maximum batch size for HF backend.") help="Maximum batch size for HF backend.")
parser.add_argument('--trust-remote-code', parser.add_argument('--trust-remote-code',
action='store_true', action='store_true',
...@@ -215,6 +227,8 @@ if __name__ == "__main__": ...@@ -215,6 +227,8 @@ if __name__ == "__main__":
elif args.backend == "hf": elif args.backend == "hf":
if args.hf_max_batch_size is None: if args.hf_max_batch_size is None:
raise ValueError("HF max batch size is required for HF backend.") raise ValueError("HF max batch size is required for HF backend.")
if args.quantization is not None:
raise ValueError("Quantization is only for vLLM backend.")
if args.tokenizer is None: if args.tokenizer is None:
args.tokenizer = args.model args.tokenizer = args.model
......
#include <torch/extension.h>
torch::Tensor awq_gemm(
torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int split_k_iters);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"awq_gemm",
&awq_gemm,
"Quantized GEMM for AWQ");
}
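For orientation, a hedged sketch of the tensor shapes this op expects, mirroring how the AWQ linear layers below call it (split_k_iters receives the pack factor, 8 for 4-bit weights). It requires the compiled extension and a CUDA device; the random packed data only exercises the shapes, not accuracy.

import torch
from vllm import quantization_ops

# K input features, N output features, group_size=128, pack_factor=8
# (eight int4 weights packed into each int32).
M, K, N, group_size, pack_factor = 16, 4096, 4096, 128, 8

x = torch.randn(M, K, dtype=torch.float16, device="cuda")
qweight = torch.randint(0, 2**31 - 1, (K, N // pack_factor),
                        dtype=torch.int32, device="cuda")
qzeros = torch.randint(0, 2**31 - 1, (K // group_size, N // pack_factor),
                       dtype=torch.int32, device="cuda")
scales = torch.randn(K // group_size, N, dtype=torch.float16, device="cuda")

out = quantization_ops.awq_gemm(x, qweight, scales, qzeros, pack_factor)
print(out.shape)  # expected: (M, N) == (16, 4096)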
/*
Adapted from https://github.com/mit-han-lab/llm-awq
Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
@article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
journal={arXiv},
year={2023}
}
*/
#pragma once
__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
{
uint4 result;
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);
// First, we extract the i4s and construct an intermediate fp16 number.
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
static constexpr uint32_t TOP_MASK = 0x00f000f0;
static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
// Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
// format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
// In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
// elt_67 to fp16 without having to shift them to the bottom bits before hand.
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
// immediately before required.
const uint32_t top_i4s = i4s >> 8;
// Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[0])
: "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[1])
: "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[2])
: "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[3])
: "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
// half2 ctor. In this case, I chose performance reliability over code readability.
// This is the half2 {1032, 1032} represented as an integer.
// static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
// Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
// This is the half2 {1 / 16, 1 / 16} represented as an integer.
static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
// This is the half2 {-72, -72} represented as an integer.
// static constexpr uint32_t NEG_72 = 0xd480d480;
// Haotian: Let's use {-64, -64}.
static constexpr uint32_t NEG_64 = 0xd400d400;
// Finally, we construct the output numbers.
// Convert elt_01
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_23
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
// Convert elt_45
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_67
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
return result;
}
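The magic constants can be checked without CUDA. Below is a minimal pure-Python sketch verifying the two arithmetic identities the PTX relies on: OR-ing a nibble into the mantissa of the half 1024.0, then removing the bias with either a subtract (bottom nibbles) or a 1/16 fma (top nibbles).

import struct

def fp16_bits_to_float(bits: int) -> float:
    # Reinterpret a 16-bit pattern as an IEEE half-precision value.
    return struct.unpack("<e", struct.pack("<H", bits))[0]

# Bottom-nibble path: (nibble | 0x6400) is the half 1024 + nibble,
# so subtracting FP16_TOP_MAGIC_NUM (1024.0) recovers the nibble.
for nibble in range(16):
    assert fp16_bits_to_float(nibble | 0x6400) - 1024.0 == float(nibble)

# Top-nibble path: ((nibble << 4) | 0x6400) is the half 1024 + 16 * nibble,
# and fma(x, 1/16, -64) maps it back to the nibble.
for nibble in range(16):
    x = fp16_bits_to_float((nibble << 4) | 0x6400)
    assert x * (1.0 / 16.0) - 64.0 == float(nibble)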
(This diff is collapsed and not shown; based on the sources listed in setup.py below, it is most likely csrc/quantization/awq/gemm_kernels.cu, the AWQ GEMM CUDA kernels adapted from llm-awq.)
...@@ -146,6 +146,20 @@ activation_extension = CUDAExtension( ...@@ -146,6 +146,20 @@ activation_extension = CUDAExtension(
) )
ext_modules.append(activation_extension) ext_modules.append(activation_extension)
# Quantization kernels.
quantization_extension = CUDAExtension(
name="vllm.quantization_ops",
sources=[
"csrc/quantization.cpp",
"csrc/quantization/awq/gemm_kernels.cu",
],
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
)
ext_modules.append(quantization_extension)
def get_path(*filepath) -> str: def get_path(*filepath) -> str:
return os.path.join(ROOT_DIR, *filepath) return os.path.join(ROOT_DIR, *filepath)
......
...@@ -43,6 +43,8 @@ class ModelConfig: ...@@ -43,6 +43,8 @@ class ModelConfig:
version. version.
max_model_len: Maximum length of a sequence (including prompt and max_model_len: Maximum length of a sequence (including prompt and
output). If None, will be derived from the model. output). If None, will be derived from the model.
quantization: Quantization method that was used to quantize the model
weights. If None, we assume the model weights are not quantized.
""" """
def __init__( def __init__(
...@@ -57,6 +59,7 @@ class ModelConfig: ...@@ -57,6 +59,7 @@ class ModelConfig:
seed: int, seed: int,
revision: Optional[str], revision: Optional[str],
max_model_len: Optional[int] = None, max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
) -> None: ) -> None:
self.model = model self.model = model
self.tokenizer = tokenizer self.tokenizer = tokenizer
...@@ -66,11 +69,13 @@ class ModelConfig: ...@@ -66,11 +69,13 @@ class ModelConfig:
self.load_format = load_format self.load_format = load_format
self.seed = seed self.seed = seed
self.revision = revision self.revision = revision
self.quantization = quantization
self.hf_config = get_config(model, trust_remote_code, revision) self.hf_config = get_config(model, trust_remote_code, revision)
self.dtype = _get_and_verify_dtype(self.hf_config, dtype) self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
self._verify_load_format() self._verify_load_format()
self._verify_tokenizer_mode() self._verify_tokenizer_mode()
self._verify_quantization()
self.max_model_len = None self.max_model_len = None
if max_model_len is not None: if max_model_len is not None:
derived_max_model_len = self.get_max_model_len() derived_max_model_len = self.get_max_model_len()
...@@ -100,6 +105,17 @@ class ModelConfig: ...@@ -100,6 +105,17 @@ class ModelConfig:
"either 'auto' or 'slow'.") "either 'auto' or 'slow'.")
self.tokenizer_mode = tokenizer_mode self.tokenizer_mode = tokenizer_mode
def _verify_quantization(self) -> None:
supported_quantization = ["awq"]
if self.quantization is None:
return
quantization = self.quantization.lower()
if quantization not in supported_quantization:
raise ValueError(
f"Unknown quantization: {self.quantization}. Must be one of "
f"{supported_quantization}.")
self.quantization = quantization
def verify_with_parallel_config( def verify_with_parallel_config(
self, self,
parallel_config: "ParallelConfig", parallel_config: "ParallelConfig",
......
...@@ -29,6 +29,7 @@ class EngineArgs: ...@@ -29,6 +29,7 @@ class EngineArgs:
max_num_seqs: int = 256 max_num_seqs: int = 256
disable_log_stats: bool = False disable_log_stats: bool = False
revision: Optional[str] = None revision: Optional[str] = None
quantization: Optional[str] = None
def __post_init__(self): def __post_init__(self):
if self.tokenizer is None: if self.tokenizer is None:
...@@ -88,7 +89,6 @@ class EngineArgs: ...@@ -88,7 +89,6 @@ class EngineArgs:
'a numpy cache to speed up the loading. ' 'a numpy cache to speed up the loading. '
'"dummy" will initialize the weights with random values, ' '"dummy" will initialize the weights with random values, '
'which is mainly for profiling.') 'which is mainly for profiling.')
# TODO(woosuk): Support FP32.
parser.add_argument( parser.add_argument(
'--dtype', '--dtype',
type=str, type=str,
...@@ -150,6 +150,13 @@ class EngineArgs: ...@@ -150,6 +150,13 @@ class EngineArgs:
parser.add_argument('--disable-log-stats', parser.add_argument('--disable-log-stats',
action='store_true', action='store_true',
help='disable logging statistics') help='disable logging statistics')
# Quantization settings.
parser.add_argument('--quantization',
'-q',
type=str,
choices=['awq', None],
default=None,
help='Method used to quantize the weights')
return parser return parser
@classmethod @classmethod
...@@ -163,12 +170,11 @@ class EngineArgs: ...@@ -163,12 +170,11 @@ class EngineArgs:
def create_engine_configs( def create_engine_configs(
self, self,
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]: ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
# Initialize the configs.
model_config = ModelConfig(self.model, self.tokenizer, model_config = ModelConfig(self.model, self.tokenizer,
self.tokenizer_mode, self.trust_remote_code, self.tokenizer_mode, self.trust_remote_code,
self.download_dir, self.load_format, self.download_dir, self.load_format,
self.dtype, self.seed, self.revision, self.dtype, self.seed, self.revision,
self.max_model_len) self.max_model_len, self.quantization)
cache_config = CacheConfig(self.block_size, cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization, self.gpu_memory_utilization,
self.swap_space) self.swap_space)
......
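A minimal sketch of the path a -q/--quantization value takes: it lands on EngineArgs and is forwarded into ModelConfig by create_engine_configs(). Note that building the configs fetches the Hugging Face config for the named model.

from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(model="facebook/opt-125m", quantization=None)
model_config, _, _, _ = engine_args.create_engine_configs()
print(model_config.quantization)  # None here; "awq" when -q awq is passed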
...@@ -80,6 +80,7 @@ class LLMEngine: ...@@ -80,6 +80,7 @@ class LLMEngine:
f"download_dir={model_config.download_dir!r}, " f"download_dir={model_config.download_dir!r}, "
f"load_format={model_config.load_format}, " f"load_format={model_config.load_format}, "
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, " f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
f"quantization={model_config.quantization}, "
f"seed={model_config.seed})") f"seed={model_config.seed})")
# TODO(woosuk): Print more configs in debug mode. # TODO(woosuk): Print more configs in debug mode.
......
from vllm.model_executor.layers.quantized_linear.awq import (
AWQColumnParallelLinear, AWQRowParallelLinear)
from vllm.model_executor.parallel_utils.tensor_parallel import (
ColumnParallelLinear, RowParallelLinear)
_QUANTIZED_LINEAR_REGISTRY = {
"awq": (AWQColumnParallelLinear, AWQRowParallelLinear),
}
class ParallelLinear:
@classmethod
def column(cls, *args, **kwargs) -> ColumnParallelLinear:
quant_config = kwargs.get("quant_config", None)
if quant_config is None:
return ColumnParallelLinear(*args, **kwargs)
name = quant_config.get_name()
if name not in _QUANTIZED_LINEAR_REGISTRY:
raise ValueError(f"No quantized linear is found for {name}")
quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][0]
return quant_linear_cls(*args, **kwargs)
@classmethod
def row(cls, *args, **kwargs) -> RowParallelLinear:
quant_config = kwargs.get("quant_config", None)
if quant_config is None:
return RowParallelLinear(*args, **kwargs)
name = quant_config.get_name()
if name not in _QUANTIZED_LINEAR_REGISTRY:
raise ValueError(f"No quantized linear is found for {name}")
quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][1]
return quant_linear_cls(*args, **kwargs)
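A small sketch of the dispatch above: with a quant_config whose get_name() is "awq", ParallelLinear.column/row construct the AWQ classes instead of the plain tensor-parallel ones. The registry is touched directly here only for illustration.

from vllm.model_executor.layers.quantized_linear import (
    _QUANTIZED_LINEAR_REGISTRY)
from vllm.model_executor.quantization_utils.awq import AWQConfig

quant_config = AWQConfig(weight_bits=4, group_size=128, zero_point=True)
column_cls, row_cls = _QUANTIZED_LINEAR_REGISTRY[quant_config.get_name()]
print(column_cls.__name__, row_cls.__name__)
# AWQColumnParallelLinear AWQRowParallelLinear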
from typing import Optional
import torch
from torch.nn.parameter import Parameter
from vllm import quantization_ops
from vllm.model_executor.parallel_utils.tensor_parallel.layers import (
ColumnParallelLinear, RowParallelLinear)
class AWQColumnParallelLinear(ColumnParallelLinear):
def create_weights(self, dtype: torch.dtype) -> None:
assert self.input_size % self.quant_config.weight_bits == 0
assert (self.output_size_per_partition %
self.quant_config.pack_factor == 0)
self.qweight = Parameter(
torch.empty(
self.input_size,
self.output_size_per_partition //
self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
self.qzeros = Parameter(
torch.empty(
self.input_size // self.quant_config.group_size,
self.output_size_per_partition //
self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
self.scales = Parameter(
torch.empty(
self.input_size // self.quant_config.group_size,
self.output_size_per_partition,
device="cuda",
dtype=dtype,
),
requires_grad=False,
)
def apply_weights(
self,
x: torch.Tensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
pack_factor = self.quant_config.pack_factor
out_shape = (x.shape[-2], self.qweight.shape[-1] * pack_factor)
reshaped_x = x.reshape(-1, x.shape[-1])
out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
self.qzeros, pack_factor)
if bias is not None:
out = out + bias
return out.reshape(out_shape)
class AWQRowParallelLinear(RowParallelLinear):
def create_weights(self, dtype: torch.dtype) -> None:
assert (self.input_size_per_partition %
self.quant_config.weight_bits == 0)
assert self.output_size % self.quant_config.pack_factor == 0
self.qweight = Parameter(
torch.empty(
self.input_size_per_partition,
self.output_size // self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
self.qzeros = Parameter(
torch.empty(
self.input_size_per_partition // self.quant_config.group_size,
self.output_size // self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
self.scales = Parameter(
torch.empty(
self.input_size_per_partition // self.quant_config.group_size,
self.output_size,
device="cuda",
dtype=dtype,
),
requires_grad=False,
)
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
pack_factor = self.quant_config.pack_factor
out_shape = (x.shape[-2], self.qweight.shape[-1] * pack_factor)
reshaped_x = x.reshape(-1, x.shape[-1])
out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
self.qzeros, pack_factor)
return out.reshape(out_shape)
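A back-of-the-envelope sketch of the parameter shapes create_weights() allocates above, assuming weight_bits=4 (pack_factor=8), group_size=128, and a square 4096x4096 projection with no tensor parallelism.

weight_bits, group_size = 4, 128
pack_factor = 32 // weight_bits              # 8 int4 weights per int32
in_features, out_features = 4096, 4096

qweight_shape = (in_features, out_features // pack_factor)               # int32
qzeros_shape = (in_features // group_size, out_features // pack_factor)  # int32
scales_shape = (in_features // group_size, out_features)                 # fp16
print(qweight_shape, qzeros_shape, scales_shape)
# (4096, 512) (32, 512) (32, 4096)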
...@@ -8,7 +8,8 @@ from transformers import PretrainedConfig ...@@ -8,7 +8,8 @@ from transformers import PretrainedConfig
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.model_executor.models import * # pylint: disable=wildcard-import from vllm.model_executor.models import * # pylint: disable=wildcard-import
from vllm.model_executor.weight_utils import initialize_dummy_weights from vllm.model_executor.weight_utils import (get_quant_config,
initialize_dummy_weights)
# TODO(woosuk): Lazy-load the model classes. # TODO(woosuk): Lazy-load the model classes.
_MODEL_REGISTRY = { _MODEL_REGISTRY = {
...@@ -30,6 +31,11 @@ _MODEL_REGISTRY = { ...@@ -30,6 +31,11 @@ _MODEL_REGISTRY = {
"RWForCausalLM": FalconForCausalLM, "RWForCausalLM": FalconForCausalLM,
} }
# FIXME(woosuk): Remove this once all models support quantization.
_MODEL_CLASSES_SUPPORT_QUANTIZATION = [
LlamaForCausalLM,
]
@contextlib.contextmanager @contextlib.contextmanager
def _set_default_torch_dtype(dtype: torch.dtype): def _set_default_torch_dtype(dtype: torch.dtype):
...@@ -52,10 +58,30 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: ...@@ -52,10 +58,30 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
def get_model(model_config: ModelConfig) -> nn.Module: def get_model(model_config: ModelConfig) -> nn.Module:
model_class = _get_model_architecture(model_config.hf_config) model_class = _get_model_architecture(model_config.hf_config)
# Get the quantization config.
quant_config = None
if model_config.quantization is not None:
if model_class not in _MODEL_CLASSES_SUPPORT_QUANTIZATION:
raise ValueError(
f"Quantization is not supported for {model_class}.")
quant_config = get_quant_config(model_config.quantization,
model_config.model,
model_config.download_dir)
supported_dtypes = quant_config.get_supported_act_dtypes()
if model_config.dtype not in supported_dtypes:
raise ValueError(
f"{model_config.dtype} is not supported for quantization "
f"method {model_config.quantization}. Supported dtypes: "
f"{supported_dtypes}")
with _set_default_torch_dtype(model_config.dtype): with _set_default_torch_dtype(model_config.dtype):
# Create a model instance. # Create a model instance.
# The weights will be initialized as empty tensors. # The weights will be initialized as empty tensors.
model = model_class(model_config.hf_config) if model_class in _MODEL_CLASSES_SUPPORT_QUANTIZATION:
model = model_class(model_config.hf_config, quant_config)
else:
model = model_class(model_config.hf_config)
if model_config.load_format == "dummy": if model_config.load_format == "dummy":
model = model.cuda() model = model.cuda()
# NOTE(woosuk): For accurate performance evaluation, we assign # NOTE(woosuk): For accurate performance evaluation, we assign
......
...@@ -36,13 +36,15 @@ from vllm.model_executor.layers.activation import SiluAndMul ...@@ -36,13 +36,15 @@ from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import ( from vllm.model_executor.layers.quantized_linear import ParallelLinear
load_tensor_parallel_weights, load_padded_tensor_parallel_vocab,
hf_model_weights_iterator)
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import ( from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) VocabParallelEmbedding)
from vllm.model_executor.quantization_utils import QuantizationConfig
from vllm.model_executor.weight_utils import (
load_tensor_parallel_weights, load_padded_tensor_parallel_vocab,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
...@@ -55,18 +57,21 @@ class LlamaMLP(nn.Module): ...@@ -55,18 +57,21 @@ class LlamaMLP(nn.Module):
hidden_size: int, hidden_size: int,
intermediate_size: int, intermediate_size: int,
hidden_act: str, hidden_act: str,
): quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__() super().__init__()
self.gate_up_proj = ColumnParallelLinear(hidden_size, self.gate_up_proj = ParallelLinear.column(hidden_size,
2 * intermediate_size, 2 * intermediate_size,
bias=False, bias=False,
gather_output=False, gather_output=False,
perform_initialization=False) perform_initialization=False,
self.down_proj = RowParallelLinear(intermediate_size, quant_config=quant_config)
hidden_size, self.down_proj = ParallelLinear.row(intermediate_size,
bias=False, hidden_size,
input_is_parallel=True, bias=False,
perform_initialization=False) input_is_parallel=True,
perform_initialization=False,
quant_config=quant_config)
if hidden_act != "silu": if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. " raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.") "Only silu is supported for now.")
...@@ -87,7 +92,8 @@ class LlamaAttention(nn.Module): ...@@ -87,7 +92,8 @@ class LlamaAttention(nn.Module):
num_heads: int, num_heads: int,
num_kv_heads: int, num_kv_heads: int,
rope_theta: float = 10000, rope_theta: float = 10000,
): quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
...@@ -103,20 +109,22 @@ class LlamaAttention(nn.Module): ...@@ -103,20 +109,22 @@ class LlamaAttention(nn.Module):
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta self.rope_theta = rope_theta
self.qkv_proj = ColumnParallelLinear( self.qkv_proj = ParallelLinear.column(
hidden_size, hidden_size,
(self.total_num_heads + 2 * self.total_num_kv_heads) * (self.total_num_heads + 2 * self.total_num_kv_heads) *
self.head_dim, self.head_dim,
bias=False, bias=False,
gather_output=False, gather_output=False,
perform_initialization=False, perform_initialization=False,
quant_config=quant_config,
) )
self.o_proj = RowParallelLinear( self.o_proj = ParallelLinear.row(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
hidden_size, hidden_size,
bias=False, bias=False,
input_is_parallel=True, input_is_parallel=True,
perform_initialization=False, perform_initialization=False,
quant_config=quant_config,
) )
self.attn = PagedAttentionWithRoPE(self.num_heads, self.attn = PagedAttentionWithRoPE(self.num_heads,
self.head_dim, self.head_dim,
...@@ -144,7 +152,11 @@ class LlamaAttention(nn.Module): ...@@ -144,7 +152,11 @@ class LlamaAttention(nn.Module):
class LlamaDecoderLayer(nn.Module): class LlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig): def __init__(
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0 # Requires transformers > 4.32.0
...@@ -154,11 +166,13 @@ class LlamaDecoderLayer(nn.Module): ...@@ -154,11 +166,13 @@ class LlamaDecoderLayer(nn.Module):
num_heads=config.num_attention_heads, num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads, num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta, rope_theta=rope_theta,
quant_config=quant_config,
) )
self.mlp = LlamaMLP( self.mlp = LlamaMLP(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
quant_config=quant_config,
) )
self.input_layernorm = RMSNorm(config.hidden_size, self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
...@@ -195,7 +209,11 @@ class LlamaDecoderLayer(nn.Module): ...@@ -195,7 +209,11 @@ class LlamaDecoderLayer(nn.Module):
class LlamaModel(nn.Module): class LlamaModel(nn.Module):
def __init__(self, config: LlamaConfig): def __init__(
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
...@@ -205,7 +223,8 @@ class LlamaModel(nn.Module): ...@@ -205,7 +223,8 @@ class LlamaModel(nn.Module):
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
vocab_size, config.hidden_size, perform_initialization=False) vocab_size, config.hidden_size, perform_initialization=False)
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers) LlamaDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
]) ])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
...@@ -237,16 +256,23 @@ class LlamaModel(nn.Module): ...@@ -237,16 +256,23 @@ class LlamaModel(nn.Module):
class LlamaForCausalLM(nn.Module): class LlamaForCausalLM(nn.Module):
def __init__(self, config): def __init__(
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.model = LlamaModel(config) self.quant_config = quant_config
self.model = LlamaModel(config, quant_config)
vocab_size = ((config.vocab_size + 63) // 64) * 64 vocab_size = ((config.vocab_size + 63) // 64) * 64
self.lm_head = ColumnParallelLinear(config.hidden_size, # NOTE: The LM head is not quantized.
vocab_size, self.lm_head = ParallelLinear.column(config.hidden_size,
bias=False, vocab_size,
gather_output=False, bias=False,
perform_initialization=False) gather_output=False,
perform_initialization=False,
quant_config=None)
self.sampler = Sampler(config.vocab_size) self.sampler = Sampler(config.vocab_size)
def forward( def forward(
...@@ -263,16 +289,28 @@ class LlamaForCausalLM(nn.Module): ...@@ -263,16 +289,28 @@ class LlamaForCausalLM(nn.Module):
input_metadata) input_metadata)
return next_tokens return next_tokens
_column_parallel_weights = [ _column_parallel_layers = []
"qkv_proj.weight", "gate_proj.weight", "up_proj.weight" _row_parallel_layers = ["o_proj", "down_proj"]
]
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
def load_weights(self, def load_weights(self,
model_name_or_path: str, model_name_or_path: str,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
load_format: str = "auto", load_format: str = "auto",
revision: Optional[str] = None): revision: Optional[str] = None):
if self.quant_config is None:
weight_suffixes = ["weight"]
else:
weight_suffixes = self.quant_config.get_tp_tensor_names()
column_parallel_weights: List[str] = []
for layer in self._column_parallel_layers:
for suffix in weight_suffixes:
column_parallel_weights.append(f"{layer}.{suffix}")
row_parallel_weights: List[str] = []
for layer in self._row_parallel_layers:
for suffix in weight_suffixes:
row_parallel_weights.append(f"{layer}.{suffix}")
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_rank = get_tensor_model_parallel_rank() tensor_model_parallel_rank = get_tensor_model_parallel_rank()
q_proj_shard_size = (self.config.hidden_size // tp_size) q_proj_shard_size = (self.config.hidden_size // tp_size)
...@@ -293,11 +331,25 @@ class LlamaForCausalLM(nn.Module): ...@@ -293,11 +331,25 @@ class LlamaForCausalLM(nn.Module):
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
is_packed = False
is_transposed = False
if self.quant_config is not None:
is_packed = self.quant_config.is_packed(name)
is_transposed = self.quant_config.is_transposed(name)
if is_transposed:
loaded_weight = loaded_weight.T
is_attention_weight = False is_attention_weight = False
for weight_name, shard_size, offset in attention_weight_specs: for weight_name, shard_size, offset in attention_weight_specs:
if weight_name not in name: if weight_name not in name:
continue continue
param = state_dict[name.replace(weight_name, "qkv_proj")] param = state_dict[name.replace(weight_name, "qkv_proj")]
if is_transposed:
param = param.T
if is_packed:
shard_size //= self.quant_config.pack_factor
offset //= self.quant_config.pack_factor
loaded_weight = loaded_weight[ loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank:shard_size * shard_size * tensor_model_parallel_rank:shard_size *
...@@ -316,6 +368,9 @@ class LlamaForCausalLM(nn.Module): ...@@ -316,6 +368,9 @@ class LlamaForCausalLM(nn.Module):
if weight_name not in name: if weight_name not in name:
continue continue
param = state_dict[name.replace(weight_name, "gate_up_proj")] param = state_dict[name.replace(weight_name, "gate_up_proj")]
if is_transposed:
param = param.T
shard_size = param.shape[0] // 2 shard_size = param.shape[0] // 2
loaded_weight = loaded_weight[ loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank:shard_size * shard_size * tensor_model_parallel_rank:shard_size *
...@@ -330,6 +385,8 @@ class LlamaForCausalLM(nn.Module): ...@@ -330,6 +385,8 @@ class LlamaForCausalLM(nn.Module):
continue continue
param = state_dict[name] param = state_dict[name]
if is_transposed:
param = param.T
if "embed_tokens" in name or "lm_head" in name: if "embed_tokens" in name or "lm_head" in name:
load_padded_tensor_parallel_vocab(param, loaded_weight, load_padded_tensor_parallel_vocab(param, loaded_weight,
...@@ -337,6 +394,6 @@ class LlamaForCausalLM(nn.Module): ...@@ -337,6 +394,6 @@ class LlamaForCausalLM(nn.Module):
continue continue
load_tensor_parallel_weights(param, loaded_weight, name, load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights, column_parallel_weights,
self._row_parallel_weights, row_parallel_weights,
tensor_model_parallel_rank) tensor_model_parallel_rank)
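A rough sketch of the packed sharding arithmetic above, assuming a 7B-style LLaMA (hidden_size=4096) with tensor_parallel_size=2 and AWQ's pack_factor=8: the packed qweight/qzeros shards shrink by the pack factor, while scales do not.

hidden_size, tp_size, pack_factor = 4096, 2, 8

q_proj_shard_size = hidden_size // tp_size             # 2048 columns in fp16 terms
packed_shard_size = q_proj_shard_size // pack_factor   # 256 int32 columns in qweight
print(q_proj_shard_size, packed_shard_size)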
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
# Parts of the code here are adapted from PyTorch # Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch # repo: https://github.com/pytorch/pytorch
from typing import Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -16,13 +16,11 @@ from vllm.model_executor.parallel_utils.parallel_state import ( ...@@ -16,13 +16,11 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
) )
from .mappings import ( from .mappings import (
copy_to_tensor_model_parallel_region,
gather_from_tensor_model_parallel_region, gather_from_tensor_model_parallel_region,
reduce_from_tensor_model_parallel_region, reduce_from_tensor_model_parallel_region,
scatter_to_tensor_model_parallel_region, scatter_to_tensor_model_parallel_region,
) )
from .random import get_cuda_rng_tracker
from .utils import ( from .utils import (
divide, divide,
VocabUtility, VocabUtility,
...@@ -65,59 +63,6 @@ def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor): ...@@ -65,59 +63,6 @@ def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
maybe_copy(attribute) maybe_copy(attribute)
def _initialize_affine_weight_gpu(weight, init_method,
partition_dim, stride=1):
"""Initialize affine weight for model parallel on GPU."""
set_tensor_model_parallel_attributes(tensor=weight,
is_parallel=True,
dim=partition_dim,
stride=stride)
with get_cuda_rng_tracker().fork():
init_method(weight)
def _initialize_affine_weight_cpu(weight, output_size, input_size,
per_partition_size, partition_dim,
init_method, stride=1,
return_master_weight=False,
*, params_dtype=None):
"""Initialize affine weight for model parallel.
Build the master weight on all processes and scatter
the relevant chunk."""
set_tensor_model_parallel_attributes(tensor=weight,
is_parallel=True,
dim=partition_dim,
stride=stride)
if params_dtype is None:
params_dtype = torch.get_default_dtype()
# Initialize master weight
master_weight = torch.empty(output_size, input_size,
dtype=torch.float,
requires_grad=False)
init_method(master_weight)
master_weight = master_weight.to(dtype=params_dtype)
# Split and copy
per_partition_per_stride_size = divide(per_partition_size, stride)
weight_list = torch.split(master_weight, per_partition_per_stride_size,
dim=partition_dim)
rank = get_tensor_model_parallel_rank()
world_size = get_tensor_model_parallel_world_size()
my_weight_list = weight_list[rank::world_size]
with torch.no_grad():
torch.cat(my_weight_list, dim=partition_dim, out=weight)
if return_master_weight:
return master_weight
return None
class VocabParallelEmbedding(torch.nn.Module): class VocabParallelEmbedding(torch.nn.Module):
"""Embedding parallelized in the vocabulary dimension. """Embedding parallelized in the vocabulary dimension.
...@@ -140,6 +85,9 @@ class VocabParallelEmbedding(torch.nn.Module): ...@@ -140,6 +85,9 @@ class VocabParallelEmbedding(torch.nn.Module):
use_cpu_initialization: bool=False, use_cpu_initialization: bool=False,
perform_initialization: bool=True): perform_initialization: bool=True):
super(VocabParallelEmbedding, self).__init__() super(VocabParallelEmbedding, self).__init__()
assert not perform_initialization
assert not use_cpu_initialization
# Keep the input dimensions. # Keep the input dimensions.
self.num_embeddings = num_embeddings self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim self.embedding_dim = embedding_dim
...@@ -162,24 +110,10 @@ class VocabParallelEmbedding(torch.nn.Module): ...@@ -162,24 +110,10 @@ class VocabParallelEmbedding(torch.nn.Module):
self.num_embeddings_per_partition = self.vocab_end_index - \ self.num_embeddings_per_partition = self.vocab_end_index - \
self.vocab_start_index self.vocab_start_index
# Allocate weights and initialize. self.weight = Parameter(torch.empty(
if use_cpu_initialization: self.num_embeddings_per_partition, self.embedding_dim,
self.weight = Parameter(torch.empty( device=torch.cuda.current_device(), dtype=params_dtype))
self.num_embeddings_per_partition, self.embedding_dim,
dtype=params_dtype))
if perform_initialization:
_initialize_affine_weight_cpu(
self.weight, self.num_embeddings, self.embedding_dim,
self.num_embeddings_per_partition, 0, init_method,
params_dtype=params_dtype)
else:
self.weight = Parameter(torch.empty(
self.num_embeddings_per_partition, self.embedding_dim,
device=torch.cuda.current_device(), dtype=params_dtype))
if perform_initialization:
_initialize_affine_weight_gpu(self.weight, init_method,
partition_dim=0, stride=1)
def forward(self, input_): def forward(self, input_):
if self.tensor_model_parallel_size > 1: if self.tensor_model_parallel_size > 1:
# Build the mask. # Build the mask.
...@@ -239,8 +173,11 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -239,8 +173,11 @@ class ColumnParallelLinear(torch.nn.Module):
params_dtype=None, params_dtype=None,
use_cpu_initialization=False, use_cpu_initialization=False,
perform_initialization=True, perform_initialization=True,
quant_config=None,
): ):
super(ColumnParallelLinear, self).__init__() super(ColumnParallelLinear, self).__init__()
assert not perform_initialization
assert not use_cpu_initialization
# Keep input parameters # Keep input parameters
self.input_size = input_size self.input_size = input_size
...@@ -250,6 +187,7 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -250,6 +187,7 @@ class ColumnParallelLinear(torch.nn.Module):
self.world_size = get_tensor_model_parallel_world_size() self.world_size = get_tensor_model_parallel_world_size()
self.output_size_per_partition = divide(output_size, self.world_size) self.output_size_per_partition = divide(output_size, self.world_size)
self.skip_bias_add = skip_bias_add self.skip_bias_add = skip_bias_add
self.quant_config = quant_config
if params_dtype is None: if params_dtype is None:
params_dtype = torch.get_default_dtype() params_dtype = torch.get_default_dtype()
...@@ -257,33 +195,13 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -257,33 +195,13 @@ class ColumnParallelLinear(torch.nn.Module):
# Parameters. # Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result # Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose. # we allocate the transpose.
# Initialize weight. self.create_weights(params_dtype)
if use_cpu_initialization:
self.weight = Parameter(torch.empty(self.output_size_per_partition,
self.input_size,
dtype=params_dtype))
if perform_initialization:
self.master_weight = _initialize_affine_weight_cpu(
self.weight, self.output_size, self.input_size,
self.output_size_per_partition, 0, init_method,
stride=stride, return_master_weight=keep_master_weight_for_test)
else:
self.weight = Parameter(torch.empty(
self.output_size_per_partition, self.input_size,
device=torch.cuda.current_device(), dtype=params_dtype))
if perform_initialization:
_initialize_affine_weight_gpu(self.weight, init_method,
partition_dim=0, stride=stride)
if bias: if bias:
if use_cpu_initialization: self.bias = Parameter(torch.empty(
self.bias = Parameter(torch.empty( self.output_size_per_partition,
self.output_size_per_partition, dtype=params_dtype)) device=torch.cuda.current_device(),
else: dtype=params_dtype))
self.bias = Parameter(torch.empty(
self.output_size_per_partition,
device=torch.cuda.current_device(),
dtype=params_dtype))
set_tensor_model_parallel_attributes(self.bias, True, 0, stride) set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
# Always initialize bias to zero. # Always initialize bias to zero.
with torch.no_grad(): with torch.no_grad():
...@@ -291,6 +209,17 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -291,6 +209,17 @@ class ColumnParallelLinear(torch.nn.Module):
else: else:
self.register_parameter('bias', None) self.register_parameter('bias', None)
def create_weights(self, dtype: torch.dtype) -> None:
self.weight = Parameter(torch.empty(
self.output_size_per_partition, self.input_size,
device=torch.cuda.current_device(), dtype=dtype))
def apply_weights(
self,
x: torch.Tensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
return F.linear(x, self.weight, bias)
def forward(self, input_): def forward(self, input_):
"""Forward of ColumnParallelLinear """Forward of ColumnParallelLinear
...@@ -306,7 +235,7 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -306,7 +235,7 @@ class ColumnParallelLinear(torch.nn.Module):
input_parallel = input_ input_parallel = input_
# Matrix multiply. # Matrix multiply.
output_parallel = F.linear(input_parallel, self.weight, bias) output_parallel = self.apply_weights(input_parallel, bias)
if self.gather_output: if self.gather_output:
# All-gather across the partitions. # All-gather across the partitions.
output = gather_from_tensor_model_parallel_region(output_parallel) output = gather_from_tensor_model_parallel_region(output_parallel)
...@@ -361,8 +290,11 @@ class RowParallelLinear(torch.nn.Module): ...@@ -361,8 +290,11 @@ class RowParallelLinear(torch.nn.Module):
use_cpu_initialization=False, use_cpu_initialization=False,
perform_initialization=True, perform_initialization=True,
reduce_results=True, reduce_results=True,
quant_config=None,
): ):
super(RowParallelLinear, self).__init__() super(RowParallelLinear, self).__init__()
assert not perform_initialization
assert not use_cpu_initialization
# Keep input parameters # Keep input parameters
self.input_size = input_size self.input_size = input_size
...@@ -376,47 +308,32 @@ class RowParallelLinear(torch.nn.Module): ...@@ -376,47 +308,32 @@ class RowParallelLinear(torch.nn.Module):
self.world_size = get_tensor_model_parallel_world_size() self.world_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, self.world_size) self.input_size_per_partition = divide(input_size, self.world_size)
self.skip_bias_add = skip_bias_add self.skip_bias_add = skip_bias_add
self.quant_config = quant_config
self.create_weights(params_dtype)
if not reduce_results and (bias and not skip_bias_add): if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the " raise ValueError("When not reduce the results, adding bias to the "
"results can lead to incorrect results") "results can lead to incorrect results")
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose.
# Initialize weight.
if use_cpu_initialization:
self.weight = Parameter(torch.empty(self.output_size,
self.input_size_per_partition,
dtype=params_dtype))
if perform_initialization:
self.master_weight = _initialize_affine_weight_cpu(
self.weight, self.output_size, self.input_size,
self.input_size_per_partition, 1, init_method,
stride=stride, return_master_weight=keep_master_weight_for_test,
params_dtype=params_dtype)
else:
self.weight = Parameter(torch.empty(
self.output_size, self.input_size_per_partition,
device=torch.cuda.current_device(), dtype=params_dtype))
if perform_initialization:
_initialize_affine_weight_gpu(self.weight, init_method,
partition_dim=1, stride=stride)
if bias: if bias:
if use_cpu_initialization: self.bias = Parameter(torch.empty(
self.bias = Parameter(torch.empty(self.output_size, self.output_size, device=torch.cuda.current_device(),
dtype=params_dtype)) dtype=params_dtype))
else:
self.bias = Parameter(torch.empty(
self.output_size, device=torch.cuda.current_device(),
dtype=params_dtype))
# Always initialize bias to zero. # Always initialize bias to zero.
with torch.no_grad(): with torch.no_grad():
self.bias.zero_() self.bias.zero_()
else: else:
self.register_parameter('bias', None) self.register_parameter('bias', None)
self.weight_t = self.weight.t()
def create_weights(self, dtype: torch.dtype) -> None:
self.weight = Parameter(torch.empty(
self.output_size, self.input_size_per_partition,
device=torch.cuda.current_device(), dtype=dtype))
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
return F.linear(x, self.weight)
def forward(self, input_): def forward(self, input_):
"""Forward of RowParallelLinear """Forward of RowParallelLinear
...@@ -434,7 +351,7 @@ class RowParallelLinear(torch.nn.Module): ...@@ -434,7 +351,7 @@ class RowParallelLinear(torch.nn.Module):
else: else:
input_parallel = scatter_to_tensor_model_parallel_region(input_) input_parallel = scatter_to_tensor_model_parallel_region(input_)
# Matrix multiply. # Matrix multiply.
output_parallel = F.linear(input_parallel, self.weight) output_parallel = self.apply_weights(input_parallel)
if self.reduce_results and self.world_size > 1: if self.reduce_results and self.world_size > 1:
output_ = reduce_from_tensor_model_parallel_region(output_parallel) output_ = reduce_from_tensor_model_parallel_region(output_parallel)
else: else:
......
from typing import Type
from vllm.model_executor.quantization_utils.awq import AWQConfig
from vllm.model_executor.quantization_utils.base import QuantizationConfig
_QUANTIZATION_REGISTRY = {
"awq": AWQConfig,
}
def get_quant_class(quantization: str) -> Type[QuantizationConfig]:
if quantization not in _QUANTIZATION_REGISTRY:
raise ValueError(f"Invalid quantization method: {quantization}")
return _QUANTIZATION_REGISTRY[quantization]
__all__ = [
"QuantizationConfig",
"get_quant_class",
]
from typing import Any, Dict, List
import torch
from vllm.model_executor.quantization_utils.base import QuantizationConfig
class AWQConfig(QuantizationConfig):
"""Config class for AWQ.
Reference: https://arxiv.org/abs/2306.00978
"""
def __init__(
self,
weight_bits: int,
group_size: int,
zero_point: bool,
) -> None:
self.weight_bits = weight_bits
self.group_size = group_size
self.zero_point = zero_point
if self.weight_bits != 4:
raise ValueError(
"Currently, only 4-bit weight quantization is supported for "
f"AWQ, but got {self.weight_bits} bits.")
self.pack_factor = 32 // self.weight_bits
def __repr__(self) -> str:
return (f"AWQConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"zero_point={self.zero_point})")
@classmethod
def get_name(cls) -> str:
return "awq"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half]
@classmethod
def get_config_filenames(cls) -> List[str]:
return [
"quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq
"quantize_config.json", # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq # pylint: disable=line-too-long
]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
zero_point = cls.get_from_keys(config, ["zero_point"])
return cls(weight_bits, group_size, zero_point)
@classmethod
def get_packed_tensor_names(cls) -> List[str]:
return ["qweight", "qzeros"]
@classmethod
def get_transposed_tensor_names(cls) -> List[str]:
return ["qweight", "qzeros", "scales"]
@classmethod
def get_tp_tensor_names(cls) -> List[str]:
return ["qweight", "qzeros", "scales"]
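A minimal sketch of parsing a quantize_config.json-style dict; the field names ("w_bit", "q_group_size", "zero_point") follow from_config() above.

from vllm.model_executor.quantization_utils.awq import AWQConfig

sample_config = {"w_bit": 4, "q_group_size": 128, "zero_point": True}
awq_config = AWQConfig.from_config(sample_config)
print(awq_config)              # AWQConfig(weight_bits=4, group_size=128, zero_point=True)
print(awq_config.pack_factor)  # 8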
from typing import Any, Dict, List
import torch
class QuantizationConfig:
@classmethod
def get_name(cls) -> str:
"""Name of the quantization method."""
raise NotImplementedError
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
"""List of supported activation dtypes."""
raise NotImplementedError
@classmethod
def get_config_filenames(cls) -> List[str]:
"""List of filenames to search for in the model directory."""
raise NotImplementedError
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
"""Create a config class from the model's quantization config."""
raise NotImplementedError
@staticmethod
def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
"""Get a value from the model's quantization config."""
for key in keys:
if key in config:
return config[key]
raise ValueError(f"Cannot find any of {keys} in the model's "
"quantization config.")
@classmethod
def get_packed_tensor_names(cls) -> List[str]:
raise NotImplementedError
@classmethod
def is_packed(cls, tensor_name: str) -> bool:
"""Returns True if a tensor is packed.
A tensor is considered packed if each element in the tensor is a
packed representation of multiple elements in the original tensor.
For example, an INT32 element in the tensor may represent 8 INT4
elements in the original tensor.
"""
return any(tag in tensor_name for tag in cls.get_packed_tensor_names())
@classmethod
def get_transposed_tensor_names(cls) -> List[str]:
raise NotImplementedError
@classmethod
def is_transposed(cls, tensor_name: str) -> bool:
"""Returns True if a tensor is transposed relative to nn.Linear.weight.
"""
return any(tag in tensor_name
for tag in cls.get_transposed_tensor_names())
@classmethod
def get_tp_tensor_names(cls) -> List[str]:
raise NotImplementedError
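A quick sketch of how the name-based checks behave for AWQ tensors; the parameter names are representative LLaMA checkpoint keys, not an exhaustive list.

from vllm.model_executor.quantization_utils.awq import AWQConfig

print(AWQConfig.is_packed("model.layers.0.self_attn.q_proj.qweight"))     # True
print(AWQConfig.is_transposed("model.layers.0.self_attn.q_proj.scales"))  # True
print(AWQConfig.is_packed("model.layers.0.self_attn.q_proj.scales"))      # False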
...@@ -4,7 +4,7 @@ import glob ...@@ -4,7 +4,7 @@ import glob
import json import json
import os import os
from collections import defaultdict from collections import defaultdict
from typing import Iterator, List, Optional, Tuple, Any from typing import Any, Iterator, List, Optional, Tuple
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from safetensors.torch import load_file, save_file, safe_open from safetensors.torch import load_file, save_file, safe_open
...@@ -13,6 +13,8 @@ import torch ...@@ -13,6 +13,8 @@ import torch
from tqdm.auto import tqdm from tqdm.auto import tqdm
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.quantization_utils import get_quant_class
from vllm.model_executor.quantization_utils.base import QuantizationConfig
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -44,7 +46,7 @@ def _shared_pointers(tensors): ...@@ -44,7 +46,7 @@ def _shared_pointers(tensors):
def convert_bin_to_safetensor_file( def convert_bin_to_safetensor_file(
pt_filename: str, pt_filename: str,
sf_filename: str, sf_filename: str,
): ) -> None:
loaded = torch.load(pt_filename, map_location="cpu") loaded = torch.load(pt_filename, map_location="cpu")
if "state_dict" in loaded: if "state_dict" in loaded:
loaded = loaded["state_dict"] loaded = loaded["state_dict"]
...@@ -78,16 +80,55 @@ def convert_bin_to_safetensor_file( ...@@ -78,16 +80,55 @@ def convert_bin_to_safetensor_file(
raise RuntimeError(f"The output tensors do not match for key {k}") raise RuntimeError(f"The output tensors do not match for key {k}")
# TODO(woosuk): Move this to other place.
def get_quant_config(
quantization: str,
model_name_or_path: str,
cache_dir: Optional[str] = None,
) -> QuantizationConfig:
is_local = os.path.isdir(model_name_or_path)
if not is_local:
# Download the config files.
with get_lock(model_name_or_path, cache_dir):
hf_folder = snapshot_download(model_name_or_path,
allow_patterns="*.json",
cache_dir=cache_dir,
tqdm_class=Disabledtqdm)
else:
hf_folder = model_name_or_path
config_files = glob.glob(os.path.join(hf_folder, "*.json"))
quant_cls = get_quant_class(quantization)
quant_config_files = [
f for f in config_files if any(
f.endswith(x) for x in quant_cls.get_config_filenames())
]
if len(quant_config_files) == 0:
raise ValueError(f"Cannot find the config file for {quantization}")
if len(quant_config_files) > 1:
raise ValueError(f"Found multiple config files for {quantization}: "
f"{quant_config_files}")
quant_config_file = quant_config_files[0]
with open(quant_config_file, "r") as f:
config = json.load(f)
return quant_cls.from_config(config)
def prepare_hf_model_weights( def prepare_hf_model_weights(
model_name_or_path: str, model_name_or_path: str,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
use_safetensors: bool = False, use_safetensors: bool = False,
fall_back_to_pt: bool = True, fall_back_to_pt: bool = True,
revision: Optional[str] = None, revision: Optional[str] = None,
): ) -> Tuple[str, List[str], bool]:
# Download model weights from huggingface. # Download model weights from huggingface.
is_local = os.path.isdir(model_name_or_path) is_local = os.path.isdir(model_name_or_path)
allow_patterns = "*.safetensors" if use_safetensors else "*.bin" if use_safetensors:
allow_patterns = ["*.safetensors"]
else:
# Some quantized models use .pt files for storing the weights.
allow_patterns = ["*.bin", "*.pt"]
if not is_local: if not is_local:
# Use file lock to prevent multiple processes from # Use file lock to prevent multiple processes from
# downloading the same model weights at the same time. # downloading the same model weights at the same time.
...@@ -99,7 +140,9 @@ def prepare_hf_model_weights( ...@@ -99,7 +140,9 @@ def prepare_hf_model_weights(
revision=revision) revision=revision)
else: else:
hf_folder = model_name_or_path hf_folder = model_name_or_path
hf_weights_files = glob.glob(os.path.join(hf_folder, allow_patterns)) hf_weights_files: List[str] = []
for pattern in allow_patterns:
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
if not use_safetensors: if not use_safetensors:
hf_weights_files = [ hf_weights_files = [
x for x in hf_weights_files if not x.endswith("training_args.bin") x for x in hf_weights_files if not x.endswith("training_args.bin")
......
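A hedged usage sketch of get_quant_config() above; the repo id reuses the example already cited in AWQConfig.get_config_filenames(), and downloading the JSON config requires network access.

from vllm.model_executor.weight_utils import get_quant_config

quant_config = get_quant_config("awq", "casperhansen/vicuna-7b-v1.5-awq")
print(quant_config.get_name(), quant_config.weight_bits, quant_config.group_size)
# e.g. "awq", 4, 128 for a typical 4-bit, group-size-128 checkpoint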