Unverified Commit a3a73ab0 authored by Cody Yu's avatar Cody Yu Committed by GitHub
Browse files

[Misc] Load FP8 kv-cache scaling factors from checkpoints (#4893)

The 2nd PR for #4532.

This PR supports loading FP8 kv-cache scaling factors from a FP8 checkpoint (with .kv_scale parameter).
parent 8674f988
...@@ -153,15 +153,13 @@ if __name__ == '__main__': ...@@ -153,15 +153,13 @@ if __name__ == '__main__':
action='store_true', action='store_true',
help='enforce eager mode and disable CUDA graph') help='enforce eager mode and disable CUDA graph')
parser.add_argument( parser.add_argument(
"--kv-cache-dtype", '--kv-cache-dtype',
type=str, type=str,
choices=['auto', 'fp8'], choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
default='auto', default="auto",
help= help='Data type for kv cache storage. If "auto", will use model '
'Data type for kv cache storage. If "auto", will use model data type. ' 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'FP8_E5M2 (without scaling) is only supported on cuda version greater ' 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
'instead supported for common inference criteria.')
parser.add_argument( parser.add_argument(
'--quantization-param-path', '--quantization-param-path',
type=str, type=str,
......
...@@ -323,15 +323,13 @@ if __name__ == "__main__": ...@@ -323,15 +323,13 @@ if __name__ == "__main__":
action="store_true", action="store_true",
help="enforce eager execution") help="enforce eager execution")
parser.add_argument( parser.add_argument(
"--kv-cache-dtype", '--kv-cache-dtype',
type=str, type=str,
choices=["auto", "fp8"], choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
default="auto", default="auto",
help= help='Data type for kv cache storage. If "auto", will use model '
'Data type for kv cache storage. If "auto", will use model data type. ' 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'FP8_E5M2 (without scaling) is only supported on cuda version greater ' 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
'common inference criteria.')
parser.add_argument( parser.add_argument(
'--quantization-param-path', '--quantization-param-path',
type=str, type=str,
......
...@@ -183,13 +183,11 @@ if __name__ == '__main__': ...@@ -183,13 +183,11 @@ if __name__ == '__main__':
parser.add_argument( parser.add_argument(
"--kv-cache-dtype", "--kv-cache-dtype",
type=str, type=str,
choices=["auto", "fp8"], choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"],
default="auto", default="auto",
help= help="Data type for kv cache storage. If 'auto', will use model "
'Data type for kv cache storage. If "auto", will use model data type. ' "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
'FP8_E5M2 (without scaling) is only supported on cuda version greater ' "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)")
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
'common inference criteria.')
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
......
...@@ -16,22 +16,35 @@ os.environ["TOKENIZERS_PARALLELISM"] = "true" ...@@ -16,22 +16,35 @@ os.environ["TOKENIZERS_PARALLELISM"] = "true"
MAX_MODEL_LEN = 1024 MAX_MODEL_LEN = 1024
MODELS = [ MODELS = [
"nm-testing/Meta-Llama-3-8B-Instruct-FP8", "nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV",
"meta-llama/Meta-Llama-3-8B-Instruct", "meta-llama/Meta-Llama-3-8B-Instruct",
] ]
EXPECTED_STRS_MAP = { EXPECTED_STRS_MAP = {
"nm-testing/Meta-Llama-3-8B-Instruct-FP8": [ "nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV": {
"auto": [
'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (', 'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (',
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', 'Artificial intelligence (AI) and human intelligence (HI) process information in distinct ways, with both',
'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne', 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne',
'Zeta-5, a highly advanced robot designed for menial labor, whirred to a', 'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep',
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
'Here are the translations:\n\n**Japanese:** (Haya aki no tori, guri o', 'Here are the translations:\n\n**Japanese:** (Haya aki no tori, nemuri no'
], ],
"meta-llama/Meta-Llama-3-8B-Instruct": [ "fp8": [
'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
'A neural network is a complex system made up of several basic components that work together to enable it to',
'Zeta-5, a highly advanced robot designed for menial labor, had never experienced anything like',
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here',
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
'Here are the translations:\n\n**Japanese:** (Haya kotori wa mushi o tsuk'
]
},
"meta-llama/Meta-Llama-3-8B-Instruct": {
"auto": [
'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained', 'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
...@@ -41,6 +54,17 @@ EXPECTED_STRS_MAP = { ...@@ -41,6 +54,17 @@ EXPECTED_STRS_MAP = {
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu' 'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu'
], ],
"fp8": [
'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne',
'In the year 2154, robotics engineer Dr. Rachel Kim had spent years perfecting her latest',
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
'Here are the translations:\n\n**Japanese:** (Haya tori, mushi o tsukamu'
]
},
} }
capability = torch.cuda.get_device_capability() capability = torch.cuda.get_device_capability()
...@@ -52,14 +76,14 @@ fp8_not_supported = (capability < ...@@ -52,14 +76,14 @@ fp8_not_supported = (capability <
@pytest.mark.skipif(fp8_not_supported, @pytest.mark.skipif(fp8_not_supported,
reason="fp8 is not supported on this GPU type.") reason="fp8 is not supported on this GPU type.")
@pytest.mark.parametrize("model_name", MODELS) @pytest.mark.parametrize("model_name", MODELS)
def test_models( @pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
example_prompts, def test_models(example_prompts, model_name, kv_cache_dtype) -> None:
model_name,
) -> None:
model = LLM(model=model_name, model = LLM(model=model_name,
max_model_len=MAX_MODEL_LEN, max_model_len=MAX_MODEL_LEN,
trust_remote_code=True,
enforce_eager=True, enforce_eager=True,
quantization="fp8") quantization="fp8",
kv_cache_dtype=kv_cache_dtype)
tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name)
formatted_prompts = [ formatted_prompts = [
...@@ -81,8 +105,8 @@ def test_models( ...@@ -81,8 +105,8 @@ def test_models(
generations.append(outputs[0].outputs[0].text) generations.append(outputs[0].outputs[0].text)
del model del model
print(generations) print(model_name, kv_cache_dtype, generations)
expected_strs = EXPECTED_STRS_MAP[model_name] expected_strs = EXPECTED_STRS_MAP[model_name][kv_cache_dtype]
for i in range(len(example_prompts)): for i in range(len(example_prompts)):
generated_str = generations[i] generated_str = generations[i]
expected_str = expected_strs[i] expected_str = expected_strs[i]
......
...@@ -7,6 +7,8 @@ import torch.nn as nn ...@@ -7,6 +7,8 @@ import torch.nn as nn
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
class Attention(nn.Module): class Attention(nn.Module):
...@@ -30,6 +32,7 @@ class Attention(nn.Module): ...@@ -30,6 +32,7 @@ class Attention(nn.Module):
alibi_slopes: Optional[List[float]] = None, alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None, sliding_window: Optional[int] = None,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
if cache_config is not None: if cache_config is not None:
...@@ -40,6 +43,27 @@ class Attention(nn.Module): ...@@ -40,6 +43,27 @@ class Attention(nn.Module):
block_size = 16 block_size = 16
if num_kv_heads is None: if num_kv_heads is None:
num_kv_heads = num_heads num_kv_heads = num_heads
# The default kv_scale is set to 1.0. This is ignored
# when kv-cache is not fp8, and should be used with
# kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
# expect the pre-quantized kv_scale to be loaded along
# with the model weights.
self.kv_cache_dtype = kv_cache_dtype
self._kv_scale = 1.0
quant_method = quant_config.get_quant_method(
self) if quant_config else None
if quant_method is not None:
if self.kv_cache_dtype == "fp8_e5m2":
raise ValueError("fp8_e5m2 kv-cache is not supported with "
"fp8 checkpoints.")
# When FP8 quantization is enabled, we make a parameter
# "kv_scale" so that it can be loaded from FP8 checkpoint.
# The kv_scale will then be converted back
# to self._kv_scale in a native float32 value after weight loading.
self.quant_method = quant_method
self.quant_method.create_weights(self)
# During model initialization, the default dtype is set as the model # During model initialization, the default dtype is set as the model
# weight and activation dtype. # weight and activation dtype.
dtype = torch.get_default_dtype() dtype = torch.get_default_dtype()
...@@ -57,10 +81,9 @@ class Attention(nn.Module): ...@@ -57,10 +81,9 @@ class Attention(nn.Module):
value: torch.Tensor, value: torch.Tensor,
kv_cache: Optional[torch.Tensor], kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
kv_scale: float = 1.0,
) -> torch.Tensor: ) -> torch.Tensor:
return self.impl.forward(query, key, value, kv_cache, attn_metadata, return self.impl.forward(query, key, value, kv_cache, attn_metadata,
kv_scale) self._kv_scale)
def extra_repr(self) -> str: def extra_repr(self) -> str:
s = f"head_size={self.impl.head_size}" # type: ignore s = f"head_size={self.impl.head_size}" # type: ignore
......
...@@ -355,14 +355,12 @@ class CacheConfig: ...@@ -355,14 +355,12 @@ class CacheConfig:
def _verify_cache_dtype(self) -> None: def _verify_cache_dtype(self) -> None:
if self.cache_dtype == "auto": if self.cache_dtype == "auto":
pass pass
elif self.cache_dtype == "fp8": elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
logger.info( logger.info(
"Using fp8 data type to store kv cache. It reduces the GPU " "Using fp8 data type to store kv cache. It reduces the GPU "
"memory footprint and boosts the performance. " "memory footprint and boosts the performance. "
"But it may cause slight accuracy drop without scaling " "Meanwhile, it may cause accuracy drop without a proper "
"factors. FP8_E5M2 (without scaling) is only supported on " "scaling factor")
"cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 "
"is instead supported for common inference criteria.")
else: else:
raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}") raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
......
...@@ -191,12 +191,11 @@ class EngineArgs: ...@@ -191,12 +191,11 @@ class EngineArgs:
parser.add_argument( parser.add_argument(
'--kv-cache-dtype', '--kv-cache-dtype',
type=str, type=str,
choices=['auto', 'fp8'], choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
default=EngineArgs.kv_cache_dtype, default=EngineArgs.kv_cache_dtype,
help='Data type for kv cache storage. If "auto", will use model ' help='Data type for kv cache storage. If "auto", will use model '
'data type. FP8_E5M2 (without scaling) is only supported on cuda ' 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead ' 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
'supported for common inference criteria.')
parser.add_argument( parser.add_argument(
'--quantization-param-path', '--quantization-param-path',
type=nullable_str, type=nullable_str,
......
...@@ -8,8 +8,9 @@ from vllm import _custom_ops as ops ...@@ -8,8 +8,9 @@ from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import print_warning_once
ACTIVATION_SCHEMES = ["static", "dynamic"] ACTIVATION_SCHEMES = ["static", "dynamic"]
...@@ -58,9 +59,13 @@ class Fp8Config(QuantizationConfig): ...@@ -58,9 +59,13 @@ class Fp8Config(QuantizationConfig):
activation_scheme=activation_scheme) activation_scheme=activation_scheme)
def get_quant_method( def get_quant_method(
self, layer: torch.nn.Module) -> Optional["Fp8LinearMethod"]: self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
return Fp8LinearMethod(self) return Fp8LinearMethod(self)
if isinstance(layer, Attention):
return Fp8KVCacheMethod(self)
return None return None
def get_scaled_act_names(self) -> List[str]: def get_scaled_act_names(self) -> List[str]:
...@@ -251,6 +256,44 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -251,6 +256,44 @@ class Fp8LinearMethod(LinearMethodBase):
return torch.narrow(output, 0, 0, x.shape[0]) return torch.narrow(output, 0, 0, x.shape[0])
class Fp8KVCacheMethod(QuantizeMethodBase):
"""Supports loading kv-cache scaling factors from FP8 checkpoints.
"""
def __init__(self, quant_config: Fp8Config):
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module):
"""Create "weight" (aka kv_scale) for an attention layer.
Args:
layer: The layer that is using the QuantizeMethodBase factory.
"""
# Initialize the KV cache scale to 1.0 as the default value.
# If the kv_scale appears in the checkpoint, it will be
# overwritten when loading weights.
layer.kv_scale = Parameter(torch.tensor(1.0), requires_grad=False)
def apply(self, layer: torch.nn.Module) -> torch.Tensor:
raise RuntimeError("Fp8KVCacheMethod.apply should not be called.")
def process_weights_after_loading(self, layer: Module) -> None:
# If the kv-cache dtype is auto, we enforce the kv-scale to be 1.0
# regardless whether the kv-scale is available in the checkpoint.
if layer.kv_cache_dtype != "auto":
kv_scale = layer.kv_scale.to("cpu").tolist()
if not isinstance(kv_scale, float):
raise ValueError("Only support per-tensor scaling factor "
"for fp8 KV cache")
layer._kv_scale = kv_scale
if layer._kv_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype:
print_warning_once(
"Using KV cache scaling factor 1.0 for fp8_e4m3. This may "
"cause accuracy issues. Please make sure kv-cache scaling "
"factor is available in the fp8 checkpoint.")
del layer.kv_scale
def all_close_1d(x: torch.Tensor) -> bool: def all_close_1d(x: torch.Tensor) -> bool:
assert len(x.shape) == 1 assert len(x.shape) == 1
return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0])) return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
......
...@@ -268,7 +268,8 @@ class ArcticAttention(nn.Module): ...@@ -268,7 +268,8 @@ class ArcticAttention(nn.Module):
self.head_dim, self.head_dim,
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config) cache_config=cache_config,
quant_config=quant_config)
def forward( def forward(
self, self,
......
...@@ -154,7 +154,8 @@ class BaiChuanAttention(nn.Module): ...@@ -154,7 +154,8 @@ class BaiChuanAttention(nn.Module):
self.attn = Attention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
scaling, scaling,
alibi_slopes=alibi_slopes) alibi_slopes=alibi_slopes,
quant_config=quant_config)
else: else:
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
...@@ -166,7 +167,8 @@ class BaiChuanAttention(nn.Module): ...@@ -166,7 +167,8 @@ class BaiChuanAttention(nn.Module):
self.attn = Attention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
self.scaling, self.scaling,
cache_config=cache_config) cache_config=cache_config,
quant_config=quant_config)
def forward( def forward(
self, self,
......
...@@ -111,7 +111,8 @@ class BloomAttention(nn.Module): ...@@ -111,7 +111,8 @@ class BloomAttention(nn.Module):
self.head_dim, self.head_dim,
scaling, scaling,
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
cache_config=cache_config) cache_config=cache_config,
quant_config=quant_config)
def forward( def forward(
self, self,
......
...@@ -86,13 +86,12 @@ class GLMAttention(nn.Module): ...@@ -86,13 +86,12 @@ class GLMAttention(nn.Module):
base=10000 * rope_ratio, base=10000 * rope_ratio,
is_neox_style=False, is_neox_style=False,
) )
self.attn = Attention( self.attn = Attention(self.num_heads,
self.num_heads,
self.head_dim, self.head_dim,
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
) quant_config=quant_config)
def forward( def forward(
self, self,
......
...@@ -177,13 +177,12 @@ class CohereAttention(nn.Module): ...@@ -177,13 +177,12 @@ class CohereAttention(nn.Module):
rope_scaling=self.rope_scaling, rope_scaling=self.rope_scaling,
is_neox_style=False, is_neox_style=False,
) )
self.attn = Attention( self.attn = Attention(self.num_heads,
self.num_heads,
self.head_dim, self.head_dim,
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
) quant_config=quant_config)
if self.use_qk_norm: if self.use_qk_norm:
self.q_norm = LayerNorm(param_shape=(self.num_heads, self.q_norm = LayerNorm(param_shape=(self.num_heads,
self.head_dim), self.head_dim),
......
...@@ -218,13 +218,12 @@ class DbrxAttention(nn.Module): ...@@ -218,13 +218,12 @@ class DbrxAttention(nn.Module):
self.q_size = self.num_heads * self.head_dim self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.attn = Attention( self.attn = Attention(self.num_heads,
self.num_heads,
self.head_dim, self.head_dim,
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
) quant_config=quant_config)
def forward( def forward(
self, self,
......
...@@ -232,7 +232,8 @@ class DeepseekAttention(nn.Module): ...@@ -232,7 +232,8 @@ class DeepseekAttention(nn.Module):
self.head_dim, self.head_dim,
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config) cache_config=cache_config,
quant_config=quant_config)
def forward( def forward(
self, self,
......
...@@ -153,7 +153,8 @@ class FalconAttention(nn.Module): ...@@ -153,7 +153,8 @@ class FalconAttention(nn.Module):
self.attn = Attention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
self.inv_norm_factor, self.inv_norm_factor,
num_kv_heads=self.num_kv_heads) num_kv_heads=self.num_kv_heads,
quant_config=quant_config)
elif self.use_alibi: elif self.use_alibi:
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads head_start = tp_rank * self.num_heads
...@@ -165,13 +166,15 @@ class FalconAttention(nn.Module): ...@@ -165,13 +166,15 @@ class FalconAttention(nn.Module):
self.head_dim, self.head_dim,
self.inv_norm_factor, self.inv_norm_factor,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
alibi_slopes=alibi_slopes) alibi_slopes=alibi_slopes,
quant_config=quant_config)
else: else:
self.attn = Attention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
scale=self.inv_norm_factor, scale=self.inv_norm_factor,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config) cache_config=cache_config,
quant_config=quant_config)
def forward( def forward(
self, self,
......
...@@ -157,7 +157,8 @@ class GemmaAttention(nn.Module): ...@@ -157,7 +157,8 @@ class GemmaAttention(nn.Module):
self.head_dim, self.head_dim,
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config) cache_config=cache_config,
quant_config=quant_config)
def forward( def forward(
self, self,
......
...@@ -75,7 +75,8 @@ class GPT2Attention(nn.Module): ...@@ -75,7 +75,8 @@ class GPT2Attention(nn.Module):
self.attn = Attention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
scale=self.scale, scale=self.scale,
cache_config=cache_config) cache_config=cache_config,
quant_config=quant_config)
def forward( def forward(
self, self,
......
...@@ -88,7 +88,8 @@ class GPTBigCodeAttention(nn.Module): ...@@ -88,7 +88,8 @@ class GPTBigCodeAttention(nn.Module):
self.head_dim, self.head_dim,
scale=self.scale, scale=self.scale,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config) cache_config=cache_config,
quant_config=quant_config)
def forward( def forward(
self, self,
......
...@@ -88,7 +88,8 @@ class GPTJAttention(nn.Module): ...@@ -88,7 +88,8 @@ class GPTJAttention(nn.Module):
self.attn = Attention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_size, self.head_size,
scaling, scaling,
cache_config=cache_config) cache_config=cache_config,
quant_config=quant_config)
def forward( def forward(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment