Unverified Commit 53ec0b79 authored by OlivierDehaene, committed by GitHub

feat(fp8): use fbgemm kernels and load fp8 weights directly (#2248)

* feat(fp8): add support for fbgemm

* allow loading fp8 weights directly

* update outlines

* fix makefile

* build fbgemm

* avoid circular import and fix dockerfile

* add default dtype

* refactored weights loader

* fix auto conversion

* fix quantization config parsing

* force new nccl on install

* missing get_weights implementation

* increase timeout
parent e5c1d6d6
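
Context for the diffs that follow: fbgemm FP8 checkpoints store their weights as torch.float8_e4m3fn tensors, so the weights loader has to recognize that dtype and avoid upcasting it to the serving dtype. The standalone sketch below is illustrative only, not code from this commit; the first hunk afterwards updates the multi-rank logging helpers, and the quantization and weight-loading changes follow.

import torch

# Illustrative only: an FP8 (e4m3) weight keeps 1 byte per element.
w_fp8 = torch.randn(16, 16, dtype=torch.float32).to(torch.float8_e4m3fn)
print(w_fp8.dtype)           # torch.float8_e4m3fn
print(w_fp8.element_size())  # 1

# Casting it to the serving dtype silently doubles the memory and drops the
# FP8 storage, which is why the loader gains a `to_dtype` escape hatch below.
print(w_fp8.to(torch.bfloat16).element_size())  # 2
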
from functools import lru_cache

from text_generation_server.utils.dist import RANK


@lru_cache(10)
def log_once(log, msg: str):
    log(msg)


def log_once(log, msg: str, master=True):
    if master:
        log_master(log, msg)
    else:
        log(msg)


def log_master(log, msg: str):
    if RANK == 0:
        log(msg)
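
The helpers above exist so that, in a sharded deployment, a message is emitted once by rank 0 instead of once per GPU. Below is a minimal self-contained sketch of the same pattern; RANK is hardcoded here (in the server it comes from text_generation_server.utils.dist), and loguru's logger is used only as an example log callable.

from functools import lru_cache

from loguru import logger

RANK = 0  # stand-in; the real value comes from text_generation_server.utils.dist


@lru_cache(10)
def log_master(log, msg: str):
    # Only the first shard logs, so each message appears once per deployment.
    if RANK == 0:
        log(msg)


log_master(logger.info, "weights loader initialized")
log_master(logger.info, "weights loader initialized")  # deduplicated by lru_cache
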
@@ -11,6 +11,7 @@ from text_generation_server.utils.weights import (
)


# TODO: Split this config to have a single config type per quant method
@dataclass
class _QuantizerConfig:
    bits: int
@@ -21,6 +22,11 @@ class _QuantizerConfig:
    sym: bool


@dataclass
class _FP8QuantizerConfig:
    activation_scale_ub: float


# We should probably do this with Pydantic JSON deserialization,
# but for now we'll stay close to the old _set_gptq_params.
def _get_quantizer_config(model_id, revision):
@@ -39,6 +45,13 @@ def _get_quantizer_config(model_id, revision):
    filename = hf_hub_download(model_id, filename=filename, revision=revision)
    with open(filename, "r") as f:
        data = json.load(f)

    # FP8 config
    if data["quantization_config"]["quant_method"] == "fbgemm_fp8":
        return _FP8QuantizerConfig(
            activation_scale_ub=data["quantization_config"]["activation_scale_ub"]
        )

    bits = data["quantization_config"]["bits"]
    groupsize = data["quantization_config"]["group_size"]
    # Order is important here, desc_act is missing on some real models
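
For reference, the fbgemm_fp8 branch above expects a quantization_config block shaped like the following in the checkpoint's config.json; the activation_scale_ub value here is made up for the example.

# Illustrative only: the shape of `data` after json.load for an fbgemm FP8 checkpoint.
data = {
    "quantization_config": {
        "quant_method": "fbgemm_fp8",
        "activation_scale_ub": 1200.0,  # hypothetical value
    }
}

if data["quantization_config"]["quant_method"] == "fbgemm_fp8":
    # Mirrors the early return above: only the activation upper bound is kept.
    activation_scale_ub = data["quantization_config"]["activation_scale_ub"]
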
@@ -99,6 +112,12 @@ def get_loader(
    if quantize in {"awq", "gptq"}:
        from text_generation_server.layers.gptq import GPTQWeightsLoader

        # TODO: improve check once we have one config type per quantize value
        if not isinstance(quantizer_config, _QuantizerConfig):
            raise ValueError(
                f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config."
            )

        return GPTQWeightsLoader(
            bits=quantizer_config.bits,
            desc_act=quantizer_config.desc_act,
@@ -127,18 +146,28 @@ def get_loader(
        from text_generation_server.layers.exl2 import Exl2WeightsLoader

        return Exl2WeightsLoader()
    elif quantize == "fp8":
        from text_generation_server.layers.fp8 import Fp8Weight

        return DefaultWeightsLoader(Fp8Weight)
    elif quantize == "marlin":
        from text_generation_server.layers.marlin import MarlinWeightsLoader

        # TODO: improve check once we have one config type per quantize value
        if not isinstance(quantizer_config, _QuantizerConfig):
            raise ValueError(
                f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config."
            )

        return MarlinWeightsLoader(
            bits=quantizer_config.bits,
            is_marlin_24=quantizer_config.checkpoint_format == "marlin_24",
        )
    elif quantize is None:
        return DefaultWeightsLoader(UnquantizedWeight)
    elif quantize == "fp8" or quantize is None:
        from text_generation_server.layers.fp8 import HybridFP8UnquantLoader

        # Since the default for the quantize config is _QuantizerConfig,
        # we need to add this check to not get an attribute error
        activation_scale_ub = None
        if isinstance(quantizer_config, _FP8QuantizerConfig):
            activation_scale_ub = quantizer_config.activation_scale_ub

        return HybridFP8UnquantLoader(activation_scale_ub, to_fp8=quantize == "fp8")
    else:
        raise ValueError(f"Unknown quantization method: {quantize}")
import torch
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum, auto
from pathlib import Path
from typing import Dict, List, Optional, Union
import torch
from typing import Dict, List, Optional, Union, Type
from safetensors import safe_open
from dataclasses import dataclass
from text_generation_server.utils.import_utils import SYSTEM
@@ -84,7 +84,7 @@ class Weight(ABC):
@dataclass
class UnquantizedWeight:
class UnquantizedWeight(Weight):
    weight: torch.Tensor

    def get_linear(self, bias: torch.Tensor):
@@ -99,7 +99,7 @@ class UnquantizedWeight:
class DefaultWeightsLoader(WeightsLoader):
    """Weight loader that loads (unquantized) Torch tensors."""

    def __init__(self, weight_class):
    def __init__(self, weight_class: Type[UnquantizedWeight]):
        """Create a loader. Weights will be wrapped using the given `weights_class`,
        normally this will be `UnquantizedWeight`, but a quantizer-specific class
        such as `Fp8Weight` can be used to quantize the weights during loading.
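
The two changes above tighten typing rather than behavior: UnquantizedWeight now implements the Weight interface, and DefaultWeightsLoader documents that its weight_class should be an UnquantizedWeight (sub)class. Below is a minimal standalone mirror of that pattern, using stand-in classes rather than the TGI ones.

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Type

import torch


class Weight(ABC):  # stand-in for the TGI Weight ABC
    @abstractmethod
    def get_linear(self, bias: torch.Tensor): ...


@dataclass
class UnquantizedWeight(Weight):  # stand-in wrapping a plain tensor
    weight: torch.Tensor

    def get_linear(self, bias: torch.Tensor):
        # Placeholder body; the real class builds a linear layer from weight and bias.
        return self.weight, bias


class DefaultWeightsLoader:
    def __init__(self, weight_class: Type[UnquantizedWeight]):
        # Loaded tensors are wrapped in `weight_class`; passing a subclass such as
        # an FP8 weight type quantizes them during loading (per the docstring above).
        self.weight_class = weight_class
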
@@ -208,20 +208,29 @@ class Weights:
    def get_shape(self, tensor_name: str):
        return self._get_slice(tensor_name).get_shape()

    def get_tensor(self, tensor_name: str, to_device=True):
    def get_tensor(self, tensor_name: str, to_device=True, to_dtype=True):
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        tensor = f.get_tensor(tensor_name)
        # Special case for gptq which shouldn't convert
        # u4 which are disguised as int32. Exl2 uses int16
        # as well.
        if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
        # as well. FP8 uses torch.float8_e4m3fn
        if (
            tensor.dtype
            not in [
                torch.float8_e4m3fn,
                torch.int16,
                torch.int32,
                torch.int64,
            ]
            and to_dtype
        ):
            tensor = tensor.to(dtype=self.dtype)
        if to_device:
            tensor = tensor.to(device=self.device)
        return tensor

    def get_partial_sharded(self, tensor_name: str, dim: int):
    def get_partial_sharded(self, tensor_name: str, dim: int, to_dtype=True):
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
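
To make the new to_dtype flag concrete, the dtype guard above reduces to the following standalone helper; maybe_cast is a made-up name used only for this illustration.

import torch


def maybe_cast(tensor: torch.Tensor, target_dtype: torch.dtype, to_dtype: bool = True):
    # Standalone reduction of the guard in Weights.get_tensor: quantizer dtypes
    # are never cast, and to_dtype=False skips the cast entirely.
    skip = (torch.float8_e4m3fn, torch.int16, torch.int32, torch.int64)
    if tensor.dtype not in skip and to_dtype:
        tensor = tensor.to(dtype=target_dtype)
    return tensor


fp8 = torch.zeros(4, 4, dtype=torch.float8_e4m3fn)
print(maybe_cast(fp8, torch.bfloat16).dtype)                # torch.float8_e4m3fn (kept)
print(maybe_cast(torch.zeros(4, 4), torch.bfloat16).dtype)  # torch.bfloat16 (cast)
print(maybe_cast(torch.zeros(4, 4), torch.bfloat16, to_dtype=False).dtype)  # torch.float32
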
@@ -241,12 +250,16 @@ class Weights:
            raise NotImplementedError("Let's make that generic when needed")
        # Special case for gptq which shouldn't convert
        # u4 which are disguised as int32. exl2 uses int16.
        if tensor.dtype not in (torch.int16, torch.int32):
        # FP8 uses torch.float8_e4m3fn.
        if (
            tensor.dtype not in (torch.float8_e4m3fn, torch.int16, torch.int32)
            and to_dtype
        ):
            tensor = tensor.to(dtype=self.dtype)
        tensor = tensor.to(device=self.device)
        return tensor

    def get_sharded(self, tensor_name: str, dim: int):
    def get_sharded(self, tensor_name: str, dim: int, to_dtype=True):
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
@@ -255,10 +268,14 @@ class Weights:
        assert (
            size % world_size == 0
        ), f"The chosen size {size} is not compatible with sharding on {world_size} shards"
        return self.get_partial_sharded(tensor_name, dim)
        return self.get_partial_sharded(tensor_name, dim, to_dtype=to_dtype)

    def get_packed_sharded(
        self, tensor_name: str, dim: int, block_sizes: Union[int, List[int]]
        self,
        tensor_name: str,
        dim: int,
        block_sizes: Union[int, List[int]],
        to_dtype=True,
    ) -> torch.Tensor:
        """
        Get a shard from a tensor that packs multiple tensors.
@@ -304,7 +321,16 @@ class Weights:
        tensor = tensor.to(device=self.device)

        # Avoid casting quantizer dtypes.
        if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
        if (
            tensor.dtype
            not in [
                torch.float8_e4m3fn,
                torch.int16,
                torch.int32,
                torch.int64,
            ]
            and to_dtype
        ):
            tensor = tensor.to(dtype=self.dtype)

        return tensor