Unverified Commit 8511669c authored by Daniël de Kok, committed by GitHub

Move quantized weight handling out of the `Weights` class (#2194)

Quantized weights were loaded in the `Weights` class, but this was
getting quite unwieldy: every higher-level method for loading weights
had become a long conditional covering all the different quantizers.

This change moves loading of quantized weights out of the `Weights`
class. This is done by defining a simple `WeightsLoader` interface
that is implemented by `Exl2WeightsLoader`, `GPTQWeightsLoader`,
and `MarlinWeightsLoader`; these implementations live in the quantizers'
respective modules. The `Weights` class still provides the low-level load
operations (such as loading plain or sharded tensors), but delegates
loads that need quantizer-specific weight processing to the configured
loader. The loaders, in turn, build on the low-level functionality
provided by `Weights` (a sketch of the interface follows below).
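
For orientation, here is a minimal sketch of what that interface looks like. The method names and signatures mirror the loaders added in this diff (`get_weights_col_packed`, `get_weights_col`, `get_multi_weights_col`, `get_weights_row`), but the actual `WeightsLoader` base class lives in `text_generation_server/utils/weights.py`, which is not part of the hunks shown here, so treat this as an approximation rather than the committed code:

```python
from abc import ABC, abstractmethod
from typing import List, Union


class WeightsLoader(ABC):
    """Quantizer-specific weight loading, layered on top of the low-level `Weights` store."""

    @abstractmethod
    def get_weights_col_packed(
        self, weights, prefix: str, block_sizes: Union[int, List[int]]
    ):
        """Load a packed, column-sharded weight (e.g. a fused QKV or gate/up projection)."""

    @abstractmethod
    def get_weights_col(self, weights, prefix: str):
        """Load a column-sharded weight for a single prefix."""

    @abstractmethod
    def get_multi_weights_col(self, weights, prefixes: List[str], dim: int):
        """Load column-sharded weights for several prefixes, concatenated along `dim`."""

    @abstractmethod
    def get_weights_row(self, weights, prefix: str):
        """Load a row-sharded weight for a single prefix."""
```

Each concrete loader receives the `Weights` instance as its first argument and uses its `get_tensor`/`get_sharded`/`get_packed_sharded` helpers, as the `Exl2WeightsLoader`, `GPTQWeightsLoader`, and `MarlinWeightsLoader` hunks below show.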

I initially tried a hierarchy in which a class like `GPTQWeights`
inherits from `Weights`, but that approach is not very flexible (e.g. it does
not work well with the new weight-storage mock used in the tests), and
the implicit indirection made the code harder to follow.
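
As a toy illustration of the composition approach taken instead, here is a self-contained sketch. The class and method names mirror the diff below, but the bodies (and the trivial `DefaultWeightsLoader`) are simplified assumptions, since `utils/weights.py` itself is not among the hunks shown here:

```python
from typing import Optional


class DefaultWeightsLoader:
    """Fallback for unquantized weights (sketch only, not the committed implementation)."""

    def get_weights_row(self, weights: "Weights", prefix: str):
        # The loader calls back into the low-level API of `Weights`.
        return weights.get_tensor(f"{prefix}.weight")


class Weights:
    """Low-level weight store; quantizer-specific processing is delegated to a loader."""

    def __init__(self, tensors: dict, weights_loader: Optional[object] = None):
        self._tensors = tensors
        # Composition instead of inheritance: the loader is a pluggable strategy.
        self.weights_loader = (
            DefaultWeightsLoader() if weights_loader is None else weights_loader
        )

    # Low-level primitive (the real class also offers get_sharded, get_packed_sharded, ...).
    def get_tensor(self, name: str):
        return self._tensors[name]

    # High-level method: no `quantize` argument any more, just delegation.
    def get_weights_row(self, prefix: str):
        return self.weights_loader.get_weights_row(self, prefix)


# Hypothetical usage with a plain (unquantized) tensor:
weights = Weights({"fc.weight": [[1.0, 2.0]]})
print(weights.get_weights_row("fc"))  # -> [[1.0, 2.0]]
```

In the real code the model classes select the loader once via `get_loader(quantize=..., model_id=..., revision=...)` and pass it to `Weights(..., weights_loader=...)`, which is exactly what the `CausalLM`, `FlashCausalLM`, and `IDEFICSSharded` hunks below do.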
parent 4c976fb4
...@@ -2,6 +2,7 @@ import torch ...@@ -2,6 +2,7 @@ import torch
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelEmbedding, TensorParallelEmbedding,
) )
from text_generation_server.utils.weights import DefaultWeightsLoader
class ProcessGroup: class ProcessGroup:
...@@ -42,7 +43,12 @@ class Weights: ...@@ -42,7 +43,12 @@ class Weights:
def test_weight_hub_files_offline_error(): def test_weight_hub_files_offline_error():
vocab_size = 17 vocab_size = 17
weights = Weights(rank=0, world_size=1, vocab_size=vocab_size, hidden_dim=256) weights = Weights(
rank=0,
world_size=1,
vocab_size=vocab_size,
hidden_dim=256,
)
embeddings = TensorParallelEmbedding("", weights) embeddings = TensorParallelEmbedding("", weights)
input_ids = torch.arange(vocab_size) input_ids = torch.arange(vocab_size)
......
import pytest import pytest
import torch import torch
from text_generation_server.utils.weights import Weights from text_generation_server.utils.weights import (
from text_generation_server.layers.gptq import GPTQWeight DefaultWeightsLoader,
from text_generation_server.layers.exl2 import Exl2Weight Weights,
from text_generation_server.layers.marlin import MarlinWeight WeightsLoader,
)
from text_generation_server.layers.gptq import GPTQWeight, GPTQWeightsLoader
from text_generation_server.layers.exl2 import Exl2Weight, Exl2WeightsLoader
from text_generation_server.layers.marlin import MarlinWeight, MarlinWeightsLoader
from types import SimpleNamespace from types import SimpleNamespace
from typing import List, Optional, Dict, Union from typing import List, Optional, Dict, Union
from pathlib import Path from pathlib import Path
@pytest.fixture
def gptq_weights_loader():
return GPTQWeightsLoader(
bits=4,
groupsize=-1,
desc_act=False,
quant_method="gptq",
quantize="gptq",
sym=True,
)
@pytest.fixture
def gptq_weights_loader_awq():
return GPTQWeightsLoader(
bits=4,
groupsize=-1,
desc_act=False,
quant_method="awq",
quantize="awq",
sym=True,
)
@pytest.fixture
def marlin_weights_loader():
return MarlinWeightsLoader(bits=4, is_marlin_24=False)
dummy_file_system = { dummy_file_system = {
"test_weights": { "test_weights": {
"layer.0.weight": torch.tensor( "layer.0.weight": torch.tensor(
...@@ -58,7 +92,7 @@ dummy_file_system = { ...@@ -58,7 +92,7 @@ dummy_file_system = {
dtype=torch.float32, dtype=torch.float32,
), ),
}, },
"test_get_multi_weights_row": { "test_get_weights_row": {
"weight.weight": torch.tensor( "weight.weight": torch.tensor(
[ [
[1, 2], [1, 2],
...@@ -101,7 +135,7 @@ dummy_file_system = { ...@@ -101,7 +135,7 @@ dummy_file_system = {
"weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
"weight.s": torch.tensor([[0.5000], [0.2500]], dtype=torch.float16), "weight.s": torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
}, },
"test_get_multi_weights_row_gptq": { "test_get_weights_row_gptq": {
"weight.qweight": torch.tensor( "weight.qweight": torch.tensor(
[ [
[1, 2], [1, 2],
...@@ -200,7 +234,7 @@ dummy_file_system = { ...@@ -200,7 +234,7 @@ dummy_file_system = {
"weight.q_scale_max": torch.tensor([100], dtype=torch.float16), "weight.q_scale_max": torch.tensor([100], dtype=torch.float16),
"weight.q_groups": torch.tensor([4], dtype=torch.int16), "weight.q_groups": torch.tensor([4], dtype=torch.int16),
}, },
"test_get_multi_weights_row_exl2": { "test_get_weights_row_exl2": {
"weight.q_weight": torch.tensor( "weight.q_weight": torch.tensor(
[ [
[1, 2], [1, 2],
...@@ -245,7 +279,7 @@ dummy_file_system = { ...@@ -245,7 +279,7 @@ dummy_file_system = {
"weight.q_scale_max": torch.tensor([100], dtype=torch.float16), "weight.q_scale_max": torch.tensor([100], dtype=torch.float16),
"weight.q_groups": torch.tensor([4], dtype=torch.int16), "weight.q_groups": torch.tensor([4], dtype=torch.int16),
}, },
"test_get_multi_weights_row_marlin": { "test_get_weights_row_marlin": {
"weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
"weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16), "weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16),
}, },
...@@ -308,6 +342,7 @@ class MockWeights(Weights): ...@@ -308,6 +342,7 @@ class MockWeights(Weights):
dummy_fs, dummy_fs,
aliases: Optional[Dict[str, List[str]]] = None, aliases: Optional[Dict[str, List[str]]] = None,
prefix: Optional[str] = None, prefix: Optional[str] = None,
weights_loader: Optional[WeightsLoader] = None,
): ):
routing = {} routing = {}
self.dummy_fs = dummy_fs self.dummy_fs = dummy_fs
...@@ -327,6 +362,9 @@ class MockWeights(Weights): ...@@ -327,6 +362,9 @@ class MockWeights(Weights):
self.dtype = dtype self.dtype = dtype
self.process_group = process_group self.process_group = process_group
self.prefix = prefix self.prefix = prefix
self.weights_loader = (
DefaultWeightsLoader() if weights_loader is None else weights_loader
)
self._handles = {} self._handles = {}
def _get_handle(self, filename: Union[Path, str]): def _get_handle(self, filename: Union[Path, str]):
...@@ -412,12 +450,10 @@ def test_get_weights_col_packed(): ...@@ -412,12 +450,10 @@ def test_get_weights_col_packed():
) )
prefix = "weight" prefix = "weight"
quantize = None
block_sizes = 1 block_sizes = 1
w = weights.get_weights_col_packed( w = weights.get_weights_col_packed(
prefix=prefix, prefix=prefix,
quantize=quantize,
block_sizes=block_sizes, block_sizes=block_sizes,
) )
...@@ -448,12 +484,10 @@ def test_get_weights_col_packed_block_size(): ...@@ -448,12 +484,10 @@ def test_get_weights_col_packed_block_size():
) )
prefix = "weight" prefix = "weight"
quantize = None
block_sizes = 2 block_sizes = 2
w = weights.get_weights_col_packed( w = weights.get_weights_col_packed(
prefix=prefix, prefix=prefix,
quantize=quantize,
block_sizes=block_sizes, block_sizes=block_sizes,
) )
...@@ -484,12 +518,10 @@ def test_get_weights_col_packed_block_size_arr(): ...@@ -484,12 +518,10 @@ def test_get_weights_col_packed_block_size_arr():
) )
prefix = "weight" prefix = "weight"
quantize = None
block_sizes = [1, 1] block_sizes = [1, 1]
w = weights.get_weights_col_packed( w = weights.get_weights_col_packed(
prefix=prefix, prefix=prefix,
quantize=quantize,
block_sizes=block_sizes, block_sizes=block_sizes,
) )
...@@ -519,11 +551,9 @@ def test_get_multi_weights_col(): ...@@ -519,11 +551,9 @@ def test_get_multi_weights_col():
) )
prefixes = ["weight", "weight"] prefixes = ["weight", "weight"]
quantize = None
w = weights.get_multi_weights_col( w = weights.get_multi_weights_col(
prefixes=prefixes, prefixes=prefixes,
quantize=quantize,
dim=0, dim=0,
) )
...@@ -545,10 +575,10 @@ def test_get_multi_weights_col(): ...@@ -545,10 +575,10 @@ def test_get_multi_weights_col():
) )
def test_get_multi_weights_row(): def test_get_weights_row():
weights = MockWeights( weights = MockWeights(
[ [
"test_get_multi_weights_row", "test_get_weights_row",
], ],
device="cpu", device="cpu",
dtype=torch.float32, dtype=torch.float32,
...@@ -557,11 +587,9 @@ def test_get_multi_weights_row(): ...@@ -557,11 +587,9 @@ def test_get_multi_weights_row():
) )
prefix = "weight" prefix = "weight"
quantize = None
w = weights.get_multi_weights_row( w = weights.get_weights_row(
prefix=prefix, prefix=prefix,
quantize=quantize,
) )
assert torch.allclose( assert torch.allclose(
...@@ -576,7 +604,7 @@ def test_get_multi_weights_row(): ...@@ -576,7 +604,7 @@ def test_get_multi_weights_row():
# test_get_weights_col # test_get_weights_col
def test_get_weights_col_awq(): def test_get_weights_col_awq(gptq_weights_loader_awq):
weights = MockWeights( weights = MockWeights(
[ [
"test_get_weights_col_gptq", "test_get_weights_col_gptq",
...@@ -585,14 +613,13 @@ def test_get_weights_col_awq(): ...@@ -585,14 +613,13 @@ def test_get_weights_col_awq():
dtype=torch.float32, dtype=torch.float32,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
weights_loader=gptq_weights_loader_awq,
) )
prefix = "weight" prefix = "weight"
quantize = "awq"
w = weights.get_weights_col( w = weights.get_weights_col(
prefix=prefix, prefix=prefix,
quantize=quantize,
) )
expected_weight = GPTQWeight( expected_weight = GPTQWeight(
...@@ -617,7 +644,7 @@ def test_get_weights_col_awq(): ...@@ -617,7 +644,7 @@ def test_get_weights_col_awq():
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
def test_get_weights_col_gtpq(): def test_get_weights_col_gtpq(gptq_weights_loader):
weights = MockWeights( weights = MockWeights(
[ [
"test_get_weights_col_gptq", "test_get_weights_col_gptq",
...@@ -626,14 +653,13 @@ def test_get_weights_col_gtpq(): ...@@ -626,14 +653,13 @@ def test_get_weights_col_gtpq():
dtype=torch.float32, dtype=torch.float32,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
weights_loader=gptq_weights_loader,
) )
prefix = "weight" prefix = "weight"
quantize = "gptq"
w = weights.get_weights_col( w = weights.get_weights_col(
prefix=prefix, prefix=prefix,
quantize=quantize,
) )
expected_weight = GPTQWeight( expected_weight = GPTQWeight(
...@@ -664,14 +690,13 @@ def test_get_weights_col_exl2(): ...@@ -664,14 +690,13 @@ def test_get_weights_col_exl2():
dtype=torch.float32, dtype=torch.float32,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
weights_loader=Exl2WeightsLoader(),
) )
prefix = "weight" prefix = "weight"
quantize = "exl2"
w = weights.get_weights_col( w = weights.get_weights_col(
prefix=prefix, prefix=prefix,
quantize=quantize,
) )
scaled_scale_max = 0.3906 * 256 scaled_scale_max = 0.3906 * 256
...@@ -692,7 +717,7 @@ def test_get_weights_col_exl2(): ...@@ -692,7 +717,7 @@ def test_get_weights_col_exl2():
assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch" assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
def test_get_weights_col_marlin(): def test_get_weights_col_marlin(marlin_weights_loader):
weights = MockWeights( weights = MockWeights(
[ [
"test_get_weights_col_marlin", "test_get_weights_col_marlin",
...@@ -701,14 +726,13 @@ def test_get_weights_col_marlin(): ...@@ -701,14 +726,13 @@ def test_get_weights_col_marlin():
dtype=torch.float16, dtype=torch.float16,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
weights_loader=marlin_weights_loader,
) )
prefix = "weight" prefix = "weight"
quantize = "marlin"
w = weights.get_weights_col( w = weights.get_weights_col(
prefix=prefix, prefix=prefix,
quantize=quantize,
) )
expected_weight = MarlinWeight( expected_weight = MarlinWeight(
...@@ -723,7 +747,7 @@ def test_get_weights_col_marlin(): ...@@ -723,7 +747,7 @@ def test_get_weights_col_marlin():
# test_get_weights_col_packed # test_get_weights_col_packed
def test_get_weights_col_packed_awq(): def test_get_weights_col_packed_awq(gptq_weights_loader_awq):
weights = MockWeights( weights = MockWeights(
[ [
"test_get_weights_col_packed_gptq", "test_get_weights_col_packed_gptq",
...@@ -732,15 +756,14 @@ def test_get_weights_col_packed_awq(): ...@@ -732,15 +756,14 @@ def test_get_weights_col_packed_awq():
dtype=torch.float32, dtype=torch.float32,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
weights_loader=gptq_weights_loader_awq,
) )
prefix = "weight" prefix = "weight"
quantize = "awq"
block_sizes = 1 block_sizes = 1
w = weights.get_weights_col_packed( w = weights.get_weights_col_packed(
prefix=prefix, prefix=prefix,
quantize=quantize,
block_sizes=block_sizes, block_sizes=block_sizes,
) )
...@@ -773,15 +796,14 @@ def test_get_weights_col_packed_exl2(): ...@@ -773,15 +796,14 @@ def test_get_weights_col_packed_exl2():
dtype=torch.float32, dtype=torch.float32,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
weights_loader=Exl2WeightsLoader(),
) )
prefix = "weight" prefix = "weight"
quantize = "exl2"
block_sizes = 1 block_sizes = 1
w = weights.get_weights_col_packed( w = weights.get_weights_col_packed(
prefix=prefix, prefix=prefix,
quantize=quantize,
block_sizes=block_sizes, block_sizes=block_sizes,
) )
...@@ -803,7 +825,7 @@ def test_get_weights_col_packed_exl2(): ...@@ -803,7 +825,7 @@ def test_get_weights_col_packed_exl2():
assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch" assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
def test_get_weights_col_packed_gptq(): def test_get_weights_col_packed_gptq(gptq_weights_loader):
weights = MockWeights( weights = MockWeights(
[ [
"test_get_weights_col_packed_gptq", "test_get_weights_col_packed_gptq",
...@@ -812,14 +834,13 @@ def test_get_weights_col_packed_gptq(): ...@@ -812,14 +834,13 @@ def test_get_weights_col_packed_gptq():
dtype=torch.float32, dtype=torch.float32,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
weights_loader=gptq_weights_loader,
) )
prefixes = ["weight"] prefixes = ["weight"]
quantize = "gptq"
w = weights.get_multi_weights_col( w = weights.get_multi_weights_col(
prefixes=prefixes, prefixes=prefixes,
quantize=quantize,
dim=0, dim=0,
) )
...@@ -842,7 +863,7 @@ def test_get_weights_col_packed_gptq(): ...@@ -842,7 +863,7 @@ def test_get_weights_col_packed_gptq():
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
def test_get_weights_col_packed_marlin(): def test_get_weights_col_packed_marlin(marlin_weights_loader):
weights = MockWeights( weights = MockWeights(
[ [
"test_get_weights_col_packed_marlin", "test_get_weights_col_packed_marlin",
...@@ -851,14 +872,13 @@ def test_get_weights_col_packed_marlin(): ...@@ -851,14 +872,13 @@ def test_get_weights_col_packed_marlin():
dtype=torch.float16, dtype=torch.float16,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
weights_loader=marlin_weights_loader,
) )
prefix = "weight" prefix = "weight"
quantize = "marlin"
w = weights.get_multi_weights_col( w = weights.get_multi_weights_col(
prefixes=[prefix], prefixes=[prefix],
quantize=quantize,
dim=0, dim=0,
) )
...@@ -876,7 +896,7 @@ def test_get_weights_col_packed_marlin(): ...@@ -876,7 +896,7 @@ def test_get_weights_col_packed_marlin():
# test_get_multi_weights_col # test_get_multi_weights_col
def test_get_multi_weights_col_awq(): def test_get_multi_weights_col_awq(gptq_weights_loader_awq):
weights = MockWeights( weights = MockWeights(
[ [
"test_get_multi_weights_col_gptq", "test_get_multi_weights_col_gptq",
...@@ -885,14 +905,13 @@ def test_get_multi_weights_col_awq(): ...@@ -885,14 +905,13 @@ def test_get_multi_weights_col_awq():
dtype=torch.float32, dtype=torch.float32,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
weights_loader=gptq_weights_loader_awq,
) )
prefixes = ["weight"] prefixes = ["weight"]
quantize = "awq"
w = weights.get_multi_weights_col( w = weights.get_multi_weights_col(
prefixes=prefixes, prefixes=prefixes,
quantize=quantize,
dim=0, dim=0,
) )
...@@ -924,22 +943,21 @@ def test_get_multi_weights_col_exl2(): ...@@ -924,22 +943,21 @@ def test_get_multi_weights_col_exl2():
dtype=torch.float32, dtype=torch.float32,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
weights_loader=Exl2WeightsLoader(),
) )
prefix = "weight" prefix = "weight"
quantize = "exl2"
try: try:
w = weights.get_multi_weights_col( w = weights.get_multi_weights_col(
prefixes=[prefix], prefixes=[prefix],
quantize=quantize,
dim=0, dim=0,
) )
except ValueError as e: except ValueError as e:
assert e.args[0] == "get_multi_weights_col is not supported for exl2" assert e.args[0] == "get_multi_weights_col is not supported for exl2"
def test_get_multi_weights_col_gptq(): def test_get_multi_weights_col_gptq(gptq_weights_loader):
weights = MockWeights( weights = MockWeights(
[ [
"test_get_multi_weights_col_gptq", "test_get_multi_weights_col_gptq",
...@@ -948,14 +966,13 @@ def test_get_multi_weights_col_gptq(): ...@@ -948,14 +966,13 @@ def test_get_multi_weights_col_gptq():
dtype=torch.float32, dtype=torch.float32,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
weights_loader=gptq_weights_loader,
) )
prefixes = ["weight"] prefixes = ["weight"]
quantize = "gptq"
w = weights.get_multi_weights_col( w = weights.get_multi_weights_col(
prefixes=prefixes, prefixes=prefixes,
quantize=quantize,
dim=0, dim=0,
) )
...@@ -978,7 +995,7 @@ def test_get_multi_weights_col_gptq(): ...@@ -978,7 +995,7 @@ def test_get_multi_weights_col_gptq():
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
def test_get_multi_weights_col_marlin(): def test_get_multi_weights_col_marlin(marlin_weights_loader):
weights = MockWeights( weights = MockWeights(
[ [
"test_get_multi_weights_col_marlin", "test_get_multi_weights_col_marlin",
...@@ -987,14 +1004,13 @@ def test_get_multi_weights_col_marlin(): ...@@ -987,14 +1004,13 @@ def test_get_multi_weights_col_marlin():
dtype=torch.float16, dtype=torch.float16,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
weights_loader=marlin_weights_loader,
) )
prefix = "weight" prefix = "weight"
quantize = "marlin"
w = weights.get_multi_weights_col( w = weights.get_multi_weights_col(
prefixes=[prefix], prefixes=[prefix],
quantize=quantize,
dim=0, dim=0,
) )
...@@ -1007,26 +1023,25 @@ def test_get_multi_weights_col_marlin(): ...@@ -1007,26 +1023,25 @@ def test_get_multi_weights_col_marlin():
assert torch.allclose(w.s, expected_weight.s), "s mismatch" assert torch.allclose(w.s, expected_weight.s), "s mismatch"
# test_get_multi_weights_row # test_get_weights_row
def test_get_multi_weights_row_awq(): def test_get_weights_row_awq(gptq_weights_loader_awq):
weights = MockWeights( weights = MockWeights(
[ [
"test_get_multi_weights_row_gptq", "test_get_weights_row_gptq",
], ],
device="cpu", device="cpu",
dtype=torch.float32, dtype=torch.float32,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
weights_loader=gptq_weights_loader_awq,
) )
prefix = "weight" prefix = "weight"
quantize = "awq"
w = weights.get_multi_weights_row( w = weights.get_weights_row(
prefix=prefix, prefix=prefix,
quantize=quantize,
) )
expected_weight = GPTQWeight( expected_weight = GPTQWeight(
...@@ -1048,23 +1063,22 @@ def test_get_multi_weights_row_awq(): ...@@ -1048,23 +1063,22 @@ def test_get_multi_weights_row_awq():
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
def test_get_multi_weights_row_exl2(): def test_get_weights_row_exl2():
weights = MockWeights( weights = MockWeights(
[ [
"test_get_multi_weights_row_exl2", "test_get_weights_row_exl2",
], ],
device="cpu", device="cpu",
dtype=torch.float32, dtype=torch.float32,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
weights_loader=Exl2WeightsLoader(),
) )
prefix = "weight" prefix = "weight"
quantize = "exl2"
w = weights.get_multi_weights_row( w = weights.get_weights_row(
prefix=prefix, prefix=prefix,
quantize=quantize,
) )
print(w) print(w)
...@@ -1086,23 +1100,22 @@ def test_get_multi_weights_row_exl2(): ...@@ -1086,23 +1100,22 @@ def test_get_multi_weights_row_exl2():
assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch" assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
def test_get_multi_weights_row_gptq(): def test_get_weights_row_gptq(gptq_weights_loader):
weights = MockWeights( weights = MockWeights(
[ [
"test_get_multi_weights_row_gptq", "test_get_weights_row_gptq",
], ],
device="cpu", device="cpu",
dtype=torch.float32, dtype=torch.float32,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
weights_loader=gptq_weights_loader,
) )
prefix = "weight" prefix = "weight"
quantize = "gptq"
w = weights.get_multi_weights_row( w = weights.get_weights_row(
prefix=prefix, prefix=prefix,
quantize=quantize,
) )
expected_weight = GPTQWeight( expected_weight = GPTQWeight(
...@@ -1124,23 +1137,22 @@ def test_get_multi_weights_row_gptq(): ...@@ -1124,23 +1137,22 @@ def test_get_multi_weights_row_gptq():
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
def test_get_multi_weights_row_marlin(): def test_get_weights_row_marlin(marlin_weights_loader):
weights = MockWeights( weights = MockWeights(
[ [
"test_get_multi_weights_row_marlin", "test_get_weights_row_marlin",
], ],
device="cpu", device="cpu",
dtype=torch.float16, dtype=torch.float16,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
weights_loader=marlin_weights_loader,
) )
prefix = "weight" prefix = "weight"
quantize = "marlin"
w = weights.get_multi_weights_row( w = weights.get_weights_row(
prefix=prefix, prefix=prefix,
quantize=quantize,
) )
expected_weight = MarlinWeight( expected_weight = MarlinWeight(
......
import torch import torch
from typing import List, Union
from dataclasses import dataclass from dataclasses import dataclass
from text_generation_server.utils.weights import WeightsLoader, Weights
@dataclass @dataclass
class Exl2Weight: class Exl2Weight:
...@@ -21,3 +24,60 @@ class Exl2Weight: ...@@ -21,3 +24,60 @@ class Exl2Weight:
@property @property
def device(self) -> torch.device: def device(self) -> torch.device:
return self.q_weight.device return self.q_weight.device
class Exl2WeightsLoader(WeightsLoader):
"""Loader for exl2-quantized weights."""
def get_weights_col_packed(
self,
weights: Weights,
prefix: str,
block_sizes: Union[int, List[int]],
):
raise RuntimeError("Column-packed weights are not supported for exl")
def get_weights_col(self, weights: Weights, prefix: str):
try:
q_weight = weights.get_tensor(f"{prefix}.q_weight")
except RuntimeError:
raise RuntimeError(
"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
)
q_scale = weights.get_tensor(f"{prefix}.q_scale")
q_invperm = weights.get_tensor(f"{prefix}.q_invperm")
q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max")
q_groups = weights.get_tensor(f"{prefix}.q_groups")
return Exl2Weight(
q_weight=q_weight,
q_scale=q_scale,
q_invperm=q_invperm,
q_scale_max=q_scale_max,
q_groups=q_groups,
)
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
raise ValueError("get_multi_weights_col is not supported for exl2")
def get_weights_row(self, weights: Weights, prefix: str):
try:
q_weight = weights.get_tensor(f"{prefix}.q_weight")
except RuntimeError:
raise RuntimeError(
"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
)
q_scale = weights.get_tensor(f"{prefix}.q_scale")
q_invperm = weights.get_tensor(f"{prefix}.q_invperm")
q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max")
q_groups = weights.get_tensor(f"{prefix}.q_groups")
return Exl2Weight(
q_weight=q_weight,
q_scale=q_scale,
q_invperm=q_invperm,
q_scale_max=q_scale_max,
q_groups=q_groups,
)
from dataclasses import dataclass from dataclasses import dataclass
from loguru import logger
import os import os
from typing import Optional from typing import List, Optional, Union
from safetensors import SafetensorError
from text_generation_server.utils.weights import Weights, WeightsLoader
import torch import torch
from text_generation_server.utils.import_utils import ( from text_generation_server.utils.import_utils import (
SYSTEM, SYSTEM,
) )
from text_generation_server.utils.log import log_once
@dataclass
class GPTQParams:
bits: int
checkpoint_format: Optional[str]
groupsize: int
desc_act: bool
quant_method: str
sym: bool
@dataclass @dataclass
...@@ -69,3 +63,341 @@ elif CAN_EXLLAMA: ...@@ -69,3 +63,341 @@ elif CAN_EXLLAMA:
pass pass
from text_generation_server.layers.gptq.quant_linear import QuantLinear from text_generation_server.layers.gptq.quant_linear import QuantLinear
class GPTQWeightsLoader(WeightsLoader):
"""
Loader for GPTQ- and AWQ-quantized weights.
"""
def __init__(
self,
*,
bits: int,
desc_act: bool,
groupsize: int,
quant_method: str,
quantize: str,
sym: bool,
):
self.bits = bits
self.desc_act = desc_act
self.groupsize = groupsize
self.quant_method = quant_method
self.quantize = quantize
self.sym = sym
def get_weights_col_packed(
self,
weights: Weights,
prefix: str,
block_sizes: Union[int, List[int]],
):
from text_generation_server.layers.marlin import (
can_use_gptq_marlin,
repack_gptq_for_marlin,
)
try:
qweight = weights.get_packed_sharded(
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized."
)
scales = weights.get_packed_sharded(
f"{prefix}.scales", dim=1, block_sizes=block_sizes
)
scales = scales.to(dtype=weights.dtype)
self._get_gptq_params(weights)
if can_use_gptq_marlin(
bits=self.bits,
groupsize=self.groupsize,
quant_method=self.quant_method,
quantize=self.quantize,
sym=self.sym,
):
g_idx = weights.get_tensor(f"{prefix}.g_idx")
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
sym=self.sym,
sharded_infeatures=False,
)
qzeros = weights.get_packed_sharded(
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
)
if self.quantize == "gptq" and self.quant_method == "gptq":
g_idx = weights.get_tensor(f"{prefix}.g_idx")
elif self.quantize == "gptq" and self.quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
from text_generation_server.layers.awq.conversion_utils import (
fast_awq_to_gptq,
)
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
g_idx = (
torch.arange(
qweight.shape[0] * (32 // self.bits),
device=qweight.device,
)
// self.groupsize
).to(dtype=torch.int32)
else:
g_idx = None
return GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=self.bits,
groupsize=self.groupsize,
use_exllama=False,
)
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
from text_generation_server.layers.marlin import (
can_use_gptq_marlin,
repack_gptq_for_marlin,
)
try:
qweight = torch.cat(
[weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized"
)
scales = torch.cat(
[weights.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
)
self._get_gptq_params(weights)
if can_use_gptq_marlin(
bits=self.bits,
groupsize=self.groupsize,
quant_method=self.quant_method,
quantize=self.quantize,
sym=self.sym,
):
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
sym=self.sym,
sharded_infeatures=False,
)
qzeros = torch.cat(
[weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
)
from text_generation_server.layers.gptq import HAS_EXLLAMA
use_exllama = (
self.bits == 4
and HAS_EXLLAMA
and self.quantize == "gptq"
and not self.desc_act
)
if self.quantize == "gptq" and self.quant_method == "gptq":
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
elif self.quantize == "gptq" and self.quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
from text_generation_server.layers.awq.conversion_utils import (
fast_awq_to_gptq,
)
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
if use_exllama:
g_idx = None
else:
g_idx = (
torch.arange(
qweight.shape[0] * (32 // self.bits),
device=qweight.device,
)
// self.groupsize
).to(dtype=torch.int32)
else:
g_idx = None
return GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=self.bits,
groupsize=self.groupsize,
use_exllama=use_exllama,
)
def get_weights_row(self, weights: Weights, prefix: str):
from text_generation_server.layers.marlin import (
can_use_gptq_marlin,
repack_gptq_for_marlin,
)
self._get_gptq_params(weights)
if can_use_gptq_marlin(
bits=self.bits,
groupsize=self.groupsize,
quant_method=self.quant_method,
quantize=self.quantize,
sym=self.sym,
):
log_once(logger.info, "Using GPTQ-Marlin kernels")
try:
qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
)
g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
if self.desc_act or self.groupsize == -1:
scales = weights.get_tensor(f"{prefix}.scales")
else:
scales = weights.get_sharded(f"{prefix}.scales", dim=0)
sharded_in_features = weights.process_group.size() > 1
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
sym=self.sym,
sharded_infeatures=sharded_in_features,
)
use_exllama = True
if self.bits != 4:
use_exllama = False
if self.desc_act:
log_once(logger.warning, "Disabling exllama because desc_act=True")
use_exllama = False
try:
qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
if self.quantize == "gptq" and self.quant_method == "gptq":
g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
else:
g_idx = None
if weights.process_group.size() > 1:
if g_idx is not None:
if (
not torch.equal(
g_idx.cpu(),
torch.tensor(
[i // self.groupsize for i in range(g_idx.shape[0])],
dtype=torch.int32,
),
)
and not (g_idx == 0).all()
):
# The Exllama implementation does not support row tensor parallelism with act-order, as
# it would require reordering input activations that are split across several GPUs
use_exllama = False
from text_generation_server.layers.gptq import (
HAS_EXLLAMA,
CAN_EXLLAMA,
GPTQWeight,
)
if use_exllama:
if not HAS_EXLLAMA:
if CAN_EXLLAMA:
log_once(
logger.warning,
"Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True",
)
use_exllama = False
else:
log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
if use_exllama and self.groupsize != -1:
qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
scales = weights.get_sharded(f"{prefix}.scales", dim=0)
else:
qzeros = weights.get_tensor(f"{prefix}.qzeros")
scales = weights.get_tensor(f"{prefix}.scales")
if use_exllama and g_idx is not None:
g_idx = g_idx - g_idx[0]
if self.quantize == "gptq" and self.quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
from text_generation_server.layers.awq.conversion_utils import (
fast_awq_to_gptq,
)
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
if use_exllama:
g_idx = None
else:
g_idx = (
torch.arange(
qweight.shape[0] * (32 // self.bits),
device=qweight.device,
)
// self.groupsize
).to(dtype=torch.int32)
return GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=self.bits,
groupsize=self.groupsize,
use_exllama=use_exllama,
)
def _get_gptq_params(self, weights: Weights):
try:
self.bits = weights.get_tensor("gptq_bits").item()
self.groupsize = weights.get_tensor("gptq_groupsize").item()
self.desc_act = False
self.sym = False
self.quant_method = "gptq"
except (SafetensorError, RuntimeError) as e:
pass
...@@ -16,6 +16,8 @@ from text_generation_server.layers.gptq.quant_linear import QuantLinear ...@@ -16,6 +16,8 @@ from text_generation_server.layers.gptq.quant_linear import QuantLinear
from loguru import logger from loguru import logger
from typing import Optional from typing import Optional
from text_generation_server.utils.weights import DefaultWeightsLoader
DEV = torch.device("cuda:0") DEV = torch.device("cuda:0")
...@@ -891,6 +893,7 @@ def quantize( ...@@ -891,6 +893,7 @@ def quantize(
dtype=torch.float16, dtype=torch.float16,
process_group=process_group, process_group=process_group,
aliases={"embed_tokens.weight": ["lm_head.weight"]}, aliases={"embed_tokens.weight": ["lm_head.weight"]},
weights_loader=DefaultWeightsLoader(),
) )
hooks = [] hooks = []
for name, module in model.named_modules(): for name, module in model.named_modules():
......
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Tuple from typing import List, Optional, Tuple, Union
from text_generation_server.utils.weights import Weights, WeightsLoader
import torch import torch
import torch.nn as nn import torch.nn as nn
from text_generation_server.layers.gptq import GPTQParams
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
try: try:
...@@ -24,16 +24,132 @@ GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128] ...@@ -24,16 +24,132 @@ GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128]
MARLIN_TILE_SIZE = 16 MARLIN_TILE_SIZE = 16
def can_use_gptq_marlin(gptq_params: GPTQParams, quantize: str) -> bool: class MarlinWeightsLoader(WeightsLoader):
"""Loader for Marlin-quantized weights."""
def __init__(self, *, bits: int, is_marlin_24: bool):
self.bits = bits
self.is_marlin_24 = is_marlin_24
def get_weights_col_packed(
self,
weights: Weights,
prefix: str,
block_sizes: Union[int, List[int]],
):
if self.is_marlin_24:
B = weights.get_packed_sharded(
f"{prefix}.B_24", dim=1, block_sizes=block_sizes
)
B_meta = weights.get_packed_sharded(
f"{prefix}.B_meta", dim=1, block_sizes=block_sizes
)
s = weights.get_packed_sharded(
f"{prefix}.s", dim=1, block_sizes=block_sizes
)
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
else:
B = weights.get_packed_sharded(
f"{prefix}.B", dim=1, block_sizes=block_sizes
)
s = weights.get_packed_sharded(
f"{prefix}.s", dim=1, block_sizes=block_sizes
)
weight = MarlinWeight(B=B, s=s)
return weight
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
if self.is_marlin_24:
try:
B = torch.cat(
[weights.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `marlin` weight, make sure the model is already quantized"
)
B_meta = torch.cat(
[weights.get_sharded(f"{p}.B_meta", dim=1) for p in prefixes], dim=1
)
s = torch.cat(
[weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
)
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
else:
try:
B = torch.cat(
[weights.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `marlin` weight, make sure the model is already quantized"
)
s = torch.cat(
[weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
)
weight = MarlinWeight(B=B, s=s)
return weight
def get_weights_row(self, weights: Weights, prefix: str):
if self.is_marlin_24:
try:
B = weights.get_sharded(f"{prefix}.B_24", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized."
)
B_meta = weights.get_sharded(f"{prefix}.B_meta", dim=0)
num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0]
if num_groups == 1:
# The number of groups is 1 when groupsize == -1; share the
# scales between all shards in this case.
s = weights.get_tensor(f"{prefix}.s")
else:
s = weights.get_sharded(f"{prefix}.s", dim=0)
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
else:
try:
B = weights.get_sharded(f"{prefix}.B", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` weight, make sure the model is already quantized."
)
num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0]
if num_groups == 1:
# The number of groups is 1 when groupsize == -1; share the
# scales between all shards in this case.
s = weights.get_tensor(f"{prefix}.s")
else:
s = weights.get_sharded(f"{prefix}.s", dim=0)
weight = MarlinWeight(B=B, s=s)
return weight
def can_use_gptq_marlin(
*, bits: int, groupsize: int, quant_method: str, quantize: str, sym: bool
) -> bool:
return ( return (
SYSTEM == "cuda" SYSTEM == "cuda"
and marlin_kernels is not None and marlin_kernels is not None
and has_sm_8_0 and has_sm_8_0
and quantize == "gptq" and quantize == "gptq"
and gptq_params.quant_method == "gptq" and quant_method == "gptq"
and gptq_params.bits in GPTQ_MARLIN_BITS and bits in GPTQ_MARLIN_BITS
and gptq_params.groupsize in GPTQ_MARLIN_GROUP_SIZES and groupsize in GPTQ_MARLIN_GROUP_SIZES
and gptq_params.sym and sym
) )
......
...@@ -52,7 +52,7 @@ class TensorParallelHead(SuperLayer): ...@@ -52,7 +52,7 @@ class TensorParallelHead(SuperLayer):
weight = weights.get_tensor(f"{prefix}.weight") weight = weights.get_tensor(f"{prefix}.weight")
except: except:
# ...otherwise they are quantized. # ...otherwise they are quantized.
weight = weights.get_weights_col(prefix, config.quantize) weight = weights.get_weights_col(prefix)
should_gather = weights.process_group.size() > 1 should_gather = weights.process_group.size() > 1
elif weights.process_group.size() > 1: elif weights.process_group.size() > 1:
try: try:
...@@ -129,9 +129,7 @@ class TensorParallelColumnLinear(SuperLayer): ...@@ -129,9 +129,7 @@ class TensorParallelColumnLinear(SuperLayer):
@classmethod @classmethod
def load_gate_up(cls, config, prefix: str, weights, bias: bool): def load_gate_up(cls, config, prefix: str, weights, bias: bool):
"""Specific method when the QKV was joined after the fact""" """Specific method when the QKV was joined after the fact"""
weight = weights.get_weights_col_packed_gate_up( weight = weights.get_weights_col_packed_gate_up(prefix)
prefix, quantize=config.quantize
)
if bias: if bias:
raise NotImplementedError("packed_gate_up only implemented without bias") raise NotImplementedError("packed_gate_up only implemented without bias")
else: else:
...@@ -152,7 +150,6 @@ class TensorParallelColumnLinear(SuperLayer): ...@@ -152,7 +150,6 @@ class TensorParallelColumnLinear(SuperLayer):
"""Specific method when the QKV was joined after the fact""" """Specific method when the QKV was joined after the fact"""
weight = weights.get_weights_col_packed_qkv( weight = weights.get_weights_col_packed_qkv(
prefix, prefix,
quantize=config.quantize,
num_heads=num_heads, num_heads=num_heads,
num_key_value_heads=num_key_value_heads, num_key_value_heads=num_key_value_heads,
) )
...@@ -165,7 +162,7 @@ class TensorParallelColumnLinear(SuperLayer): ...@@ -165,7 +162,7 @@ class TensorParallelColumnLinear(SuperLayer):
@classmethod @classmethod
def load(cls, config, prefix: str, weights, bias: bool): def load(cls, config, prefix: str, weights, bias: bool):
weight = weights.get_weights_col(prefix, config.quantize) weight = weights.get_weights_col(prefix)
if bias: if bias:
bias = weights.get_sharded(f"{prefix}.bias", dim=0) bias = weights.get_sharded(f"{prefix}.bias", dim=0)
else: else:
...@@ -178,14 +175,12 @@ class TensorParallelColumnLinear(SuperLayer): ...@@ -178,14 +175,12 @@ class TensorParallelColumnLinear(SuperLayer):
if config.quantize == "exl2": if config.quantize == "exl2":
linears = [] linears = []
for prefix in prefixes: for prefix in prefixes:
weight = weights.get_weights_col(prefix, config.quantize) weight = weights.get_weights_col(prefix)
b = weights.get_tensor(f"{prefix}.bias") if bias else None b = weights.get_tensor(f"{prefix}.bias") if bias else None
linears.append(get_linear(weight, b, config.quantize)) linears.append(get_linear(weight, b, config.quantize))
linear = LayerConcat(linears) linear = LayerConcat(linears)
else: else:
weight = weights.get_multi_weights_col( weight = weights.get_multi_weights_col(prefixes, dim=dim)
prefixes, quantize=config.quantize, dim=dim
)
if bias: if bias:
b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes] b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
bias = torch.cat(b, dim=dim) bias = torch.cat(b, dim=dim)
...@@ -202,7 +197,7 @@ class TensorParallelRowLinear(SuperLayer): ...@@ -202,7 +197,7 @@ class TensorParallelRowLinear(SuperLayer):
@classmethod @classmethod
def load(cls, config, prefix: str, weights, bias: bool): def load(cls, config, prefix: str, weights, bias: bool):
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) weight = weights.get_weights_row(prefix)
if bias and weights.process_group.rank() == 0: if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process # Rank is only on the first rank process
......
...@@ -20,6 +20,7 @@ from text_generation_server.utils import ( ...@@ -20,6 +20,7 @@ from text_generation_server.utils import (
from text_generation_server.models import Model from text_generation_server.models import Model
from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.quantization import get_loader
from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models.types import ( from text_generation_server.models.types import (
Batch, Batch,
...@@ -546,12 +547,17 @@ class CausalLM(Model): ...@@ -546,12 +547,17 @@ class CausalLM(Model):
tokenizer.pad_token_id = config.pad_token_id tokenizer.pad_token_id = config.pad_token_id
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
weights_loader = get_loader(
quantize=quantize, model_id=model_id, revision=revision
)
filenames = weight_files(model_id, revision=revision, extension=".safetensors") filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights( weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group filenames,
device=device,
dtype=dtype,
process_group=self.process_group,
weights_loader=weights_loader,
) )
if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
prefix = "" prefix = ""
model = model_class(prefix, config, weights) model = model_class(prefix, config, weights)
......
...@@ -163,7 +163,6 @@ def _load_gqa(config, prefix: str, weights): ...@@ -163,7 +163,6 @@ def _load_gqa(config, prefix: str, weights):
weight = weights.get_multi_weights_col( weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
quantize=config.quantize,
dim=0, dim=0,
) )
......
...@@ -141,7 +141,6 @@ def _load_gqa(config, prefix: str, weights): ...@@ -141,7 +141,6 @@ def _load_gqa(config, prefix: str, weights):
weight = weights.get_multi_weights_col( weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
quantize=config.quantize,
dim=0, dim=0,
) )
......
...@@ -141,7 +141,6 @@ def _load_gqa(config, prefix: str, weights): ...@@ -141,7 +141,6 @@ def _load_gqa(config, prefix: str, weights):
weight = weights.get_multi_weights_col( weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
quantize=config.quantize,
dim=0, dim=0,
) )
......
...@@ -61,7 +61,6 @@ def _load_qkv_gptq(config, prefix: str, weights): ...@@ -61,7 +61,6 @@ def _load_qkv_gptq(config, prefix: str, weights):
# Weights # Weights
weight = weights.get_weights_col_packed_qkv( weight = weights.get_weights_col_packed_qkv(
f"{prefix}.c_attn", f"{prefix}.c_attn",
config.quantize,
config.num_attention_heads, config.num_attention_heads,
config.num_attention_heads, config.num_attention_heads,
) )
...@@ -137,7 +136,7 @@ def load_row(config, prefix: str, weights, bias: bool): ...@@ -137,7 +136,7 @@ def load_row(config, prefix: str, weights, bias: bool):
"""load_row, but with transposed weight matrices.""" """load_row, but with transposed weight matrices."""
if config.quantize == "gptq": if config.quantize == "gptq":
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) weight = weights.get_weights_row(prefix)
else: else:
weight = weights.get_sharded(f"{prefix}.weight", dim=0).T weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
...@@ -155,9 +154,7 @@ def load_row(config, prefix: str, weights, bias: bool): ...@@ -155,9 +154,7 @@ def load_row(config, prefix: str, weights, bias: bool):
def load_col(config, prefix: str, weights, bias: bool): def load_col(config, prefix: str, weights, bias: bool):
"""load_col, but with transposed weight matrices.""" """load_col, but with transposed weight matrices."""
if config.quantize == "gptq": if config.quantize == "gptq":
weight = weights.get_multi_weights_col( weight = weights.get_multi_weights_col([prefix], dim=1)
[prefix], quantize=config.quantize, dim=1
)
else: else:
weight = weights.get_sharded(f"{prefix}.weight", dim=1).T weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
......
...@@ -135,7 +135,6 @@ def _load_gqa(config, prefix: str, weights): ...@@ -135,7 +135,6 @@ def _load_gqa(config, prefix: str, weights):
weight = weights.get_multi_weights_col( weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
quantize=config.quantize,
dim=0, dim=0,
) )
......
...@@ -48,7 +48,7 @@ from text_generation_server.layers.rotary import ( ...@@ -48,7 +48,7 @@ from text_generation_server.layers.rotary import (
def load_row(config, prefix: str, weights, bias: bool): def load_row(config, prefix: str, weights, bias: bool):
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) weight = weights.get_weights_row(prefix)
if bias and weights.process_group.rank() == 0: if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process # Rank is only on the first rank process
...@@ -64,7 +64,7 @@ def load_row(config, prefix: str, weights, bias: bool): ...@@ -64,7 +64,7 @@ def load_row(config, prefix: str, weights, bias: bool):
def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size): def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
weight = weights.get_multi_weights_col([prefix], quantize=config.quantize, dim=0) weight = weights.get_multi_weights_col([prefix], dim=0)
if isinstance(weight, torch.Tensor): if isinstance(weight, torch.Tensor):
# Only on non quantized versions # Only on non quantized versions
weight = ( weight = (
......
...@@ -85,7 +85,6 @@ def _load_gqa(config, prefix: str, weights): ...@@ -85,7 +85,6 @@ def _load_gqa(config, prefix: str, weights):
weight = weights.get_multi_weights_col( weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
quantize=config.quantize,
dim=0, dim=0,
) )
......
...@@ -23,7 +23,7 @@ from text_generation_server.layers.attention import ( ...@@ -23,7 +23,7 @@ from text_generation_server.layers.attention import (
def load_row(config, prefix: str, weights, bias: bool): def load_row(config, prefix: str, weights, bias: bool):
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) weight = weights.get_weights_row(prefix)
if bias and weights.process_group.rank() == 0: if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process # Rank is only on the first rank process
......
...@@ -17,6 +17,7 @@ from text_generation_server.layers import ( ...@@ -17,6 +17,7 @@ from text_generation_server.layers import (
TensorParallelEmbedding, TensorParallelEmbedding,
get_linear, get_linear,
) )
from text_generation_server.layers.gptq import GPTQWeightsLoader
from text_generation_server.layers.layernorm import ( from text_generation_server.layers.layernorm import (
FastLayerNorm, FastLayerNorm,
) )
...@@ -81,11 +82,13 @@ def _load_multi_mqa_gptq( ...@@ -81,11 +82,13 @@ def _load_multi_mqa_gptq(
qzeros = torch.cat([q_tensor, kv_tensor], dim=1) qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
qzeros = qzeros.to(device=weights.device) qzeros = qzeros.to(device=weights.device)
gptq_params = weights._get_gptq_params() loader = weights.weights_loader
if gptq_params.quant_method == "gptq": assert isinstance(loader, GPTQWeightsLoader)
loader._get_gptq_params(weights)
if loader.quant_method == "gptq":
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx") g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
g_idx = g_idx.to(device=weights.device) g_idx = g_idx.to(device=weights.device)
elif gptq_params.quant_method == "awq": elif loader.quant_method == "awq":
g_idx = None g_idx = None
from text_generation_server.layers.awq.conversion_utils import ( from text_generation_server.layers.awq.conversion_utils import (
fast_awq_to_gptq, fast_awq_to_gptq,
...@@ -100,8 +103,8 @@ def _load_multi_mqa_gptq( ...@@ -100,8 +103,8 @@ def _load_multi_mqa_gptq(
qzeros=qzeros, qzeros=qzeros,
scales=scales, scales=scales,
g_idx=g_idx, g_idx=g_idx,
bits=gptq_params.bits, bits=loader.bits,
groupsize=gptq_params.groupsize, groupsize=loader.groupsize,
use_exllama=HAS_EXLLAMA, use_exllama=HAS_EXLLAMA,
) )
...@@ -197,9 +200,7 @@ def load_col(config, prefix: str, weights, bias: bool): ...@@ -197,9 +200,7 @@ def load_col(config, prefix: str, weights, bias: bool):
if config.transpose: if config.transpose:
weight = weights.get_sharded(f"{prefix}.weight", dim=1).T weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
else: else:
weight = weights.get_multi_weights_col( weight = weights.get_multi_weights_col([prefix], dim=0)
[prefix], quantize=config.quantize, dim=0
)
if bias: if bias:
bias = weights.get_sharded(f"{prefix}.bias", dim=0) bias = weights.get_sharded(f"{prefix}.bias", dim=0)
...@@ -212,7 +213,7 @@ def load_row(config, prefix: str, weights, bias: bool): ...@@ -212,7 +213,7 @@ def load_row(config, prefix: str, weights, bias: bool):
if config.transpose: if config.transpose:
weight = weights.get_sharded(f"{prefix}.weight", dim=0).T weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
else: else:
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) weight = weights.get_weights_row(prefix)
if bias and weights.process_group.rank() == 0: if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process # Rank is only on the first rank process
......
...@@ -126,7 +126,6 @@ def _load_gqa(config, prefix: str, weights): ...@@ -126,7 +126,6 @@ def _load_gqa(config, prefix: str, weights):
weight = weights.get_multi_weights_col( weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
quantize=config.quantize,
dim=0, dim=0,
) )
......
...@@ -50,6 +50,7 @@ from text_generation_server.models.globals import ( ...@@ -50,6 +50,7 @@ from text_generation_server.models.globals import (
from text_generation_server.layers.attention import Seqlen from text_generation_server.layers.attention import Seqlen
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from text_generation_server.utils.dist import MEMORY_FRACTION from text_generation_server.utils.dist import MEMORY_FRACTION
from text_generation_server.utils.quantization import get_loader
from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments
from text_generation_server.utils.import_utils import ( from text_generation_server.utils.import_utils import (
...@@ -881,12 +882,16 @@ class FlashCausalLM(Model): ...@@ -881,12 +882,16 @@ class FlashCausalLM(Model):
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
weights_loader = get_loader(quantize, model_id, revision)
filenames = weight_files(model_id, revision=revision, extension=".safetensors") filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights( weights = Weights(
filenames, device, dtype, process_group=self.process_group, aliases=aliases filenames,
device,
dtype,
process_group=self.process_group,
aliases=aliases,
weights_loader=weights_loader,
) )
if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
prefix = "" prefix = ""
model = model_class(prefix, config, weights) model = model_class(prefix, config, weights)
......
...@@ -23,6 +23,7 @@ from text_generation_server.utils import ( ...@@ -23,6 +23,7 @@ from text_generation_server.utils import (
weight_files, weight_files,
Weights, Weights,
) )
from text_generation_server.utils.quantization import get_loader
class IDEFICSSharded(IdeficsCausalLM): class IDEFICSSharded(IdeficsCausalLM):
...@@ -70,6 +71,9 @@ class IDEFICSSharded(IdeficsCausalLM): ...@@ -70,6 +71,9 @@ class IDEFICSSharded(IdeficsCausalLM):
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
weights_loader = get_loader(
quantize=quantize, model_id=model_id, revision=revision
)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors") filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights( weights = Weights(
...@@ -77,6 +81,7 @@ class IDEFICSSharded(IdeficsCausalLM): ...@@ -77,6 +81,7 @@ class IDEFICSSharded(IdeficsCausalLM):
device=device, device=device,
dtype=dtype, dtype=dtype,
process_group=self.process_group, process_group=self.process_group,
weights_loader=weights_loader,
) )
model = IdeficsForVisionText2Text(config, weights) model = IdeficsForVisionText2Text(config, weights)
......