Unverified Commit 3eeb148f authored by Elsa Granger's avatar Elsa Granger Committed by GitHub
Browse files

[Misc] Pass cutlass_fp8_supported correctly in fbgemm_fp8 (#6871)

parent b1366a95
...@@ -9,6 +9,7 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, ...@@ -9,6 +9,7 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.fp8 import cutlass_fp8_supported
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin) apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
...@@ -72,6 +73,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase): ...@@ -72,6 +73,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
def __init__(self, quant_config: FBGEMMFp8Config): def __init__(self, quant_config: FBGEMMFp8Config):
self.quant_config = quant_config self.quant_config = quant_config
self.cutlass_fp8_supported = cutlass_fp8_supported()
def create_weights( def create_weights(
self, self,
...@@ -139,11 +141,12 @@ class FBGEMMFp8LinearMethod(LinearMethodBase): ...@@ -139,11 +141,12 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
size_k=layer.input_size_per_partition, size_k=layer.input_size_per_partition,
bias=bias) bias=bias)
return apply_fp8_linear(input=x, return apply_fp8_linear(
weight=layer.weight, input=x,
weight_scale=layer.weight_scale, weight=layer.weight,
input_scale=None, weight_scale=layer.weight_scale,
input_scale_ub=layer.input_scale_ub, input_scale=None,
bias=bias, input_scale_ub=layer.input_scale_ub,
cutlass_fp8_supported=True, bias=bias,
use_per_token_if_dynamic=True) cutlass_fp8_supported=self.cutlass_fp8_supported,
use_per_token_if_dynamic=True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment